Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2025-12-21 03:50:50 +08:00

Commit 7a81095476: Merge branch 'master' into v3-dynamic-combo
21  .github/PULL_REQUEST_TEMPLATE/api-node.md  (vendored, new file)
@@ -0,0 +1,21 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->

## API Node PR Checklist

### Scope

- [ ] **Is API Node Change**

### Pricing & Billing

- [ ] **Need pricing update**
- [ ] **No pricing update**

If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Auto‑billing tests updated and passing

### QA

- [ ] **QA done**
- [ ] **QA not required**

### Comms

- [ ] Informed **Kosinkadink**
58  .github/workflows/api-node-template.yml  (vendored, new file)
@@ -0,0 +1,58 @@
name: Append API Node PR template

on:
  pull_request_target:
    types: [opened, reopened, synchronize, ready_for_review]
    paths:
      - 'comfy_api_nodes/**'   # only run if these files changed

permissions:
  contents: read
  pull-requests: write

jobs:
  inject:
    runs-on: ubuntu-latest
    steps:
      - name: Ensure template exists and append to PR body
        uses: actions/github-script@v7
        with:
          script: |
            const { owner, repo } = context.repo;
            const number = context.payload.pull_request.number;
            const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
            const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';

            const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });

            let templateText;
            try {
              const res = await github.rest.repos.getContent({
                owner,
                repo,
                path: templatePath,
                ref: pr.base.ref
              });
              const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
              templateText = buf.toString('utf8');
            } catch (e) {
              core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
              return;
            }

            // Enforce the presence of the marker inside the template (for idempotence)
            if (!templateText.includes(marker)) {
              core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
              return;
            }

            // If the PR already contains the marker, do not append again.
            const body = pr.body || '';
            if (body.includes(marker)) {
              core.info('Template already present in PR body; nothing to inject.');
              return;
            }

            const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
            await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
            core.notice('API Node template appended to PR description.');
17  .github/workflows/release-stable-all.yml  (vendored)
@@ -43,6 +43,23 @@ jobs:
       test_release: true
     secrets: inherit

+  release_nvidia_cu126:
+    permissions:
+      contents: "write"
+      packages: "write"
+      pull-requests: "read"
+    name: "Release NVIDIA cu126"
+    uses: ./.github/workflows/stable-release.yml
+    with:
+      git_tag: ${{ inputs.git_tag }}
+      cache_tag: "cu126"
+      python_minor: "12"
+      python_patch: "10"
+      rel_name: "nvidia"
+      rel_extra_name: "_cu126"
+      test_release: true
+    secrets: inherit
+
   release_amd_rocm:
     permissions:
       contents: "write"
168  QUANTIZATION.md  (new file)
@@ -0,0 +1,168 @@
# The Comfy guide to Quantization

## How does quantization work?

Quantization aims to map a high-precision value x_f to a lower-precision format with minimal loss in accuracy. These smaller formats reduce the model's memory footprint and increase throughput by using specialized hardware.

When we simply convert a value from FP16 to FP8 with round-to-nearest, we can hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values.
- The original values are concentrated in a small range (e.g. -1, 1), leaving many FP8 bits "unused".

By using a scaling factor, we aim to map these values into the range of the quantized dtype, making use of the full spectrum. One of the easiest and most common approaches is per-tensor absolute-maximum scaling.

```
absmax = max(abs(tensor))
scale = absmax / max_dynamic_range_low_precision

# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)

# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale

tensor_dq ~ tensor
```

Given that additional information (the scaling factor) is needed to "interpret" the quantized values, we describe these as derived datatypes.
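
To make the scaled round-trip concrete, here is a minimal runnable sketch using PyTorch's `float8_e4m3fn` dtype (assumes a PyTorch build with FP8 dtypes; the tensor shapes and names are illustrative, not part of ComfyUI's API):

```python
import torch

w = torch.randn(1024, 1024, dtype=torch.float16) * 0.02   # values concentrated near zero

f8_max = torch.finfo(torch.float8_e4m3fn).max              # 448.0 for E4M3
scale = w.abs().max().to(torch.float32) / f8_max            # per-tensor scaling factor

# Quantization: rescale into the FP8 range, then cast (round-to-nearest)
w_q = (w.to(torch.float32) / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)

# De-quantization: cast back up and undo the scaling
w_dq = (w_q.to(torch.float32) * scale).to(torch.float16)

print((w - w_dq).abs().max())  # w_dq ~ w, up to a small quantization error
```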

## Quantization in Comfy

```
QuantizedTensor (torch.Tensor subclass)
        ↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
        ↓
MixedPrecisionOps + Metadata Detection
```

### Representation

To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor: the `QuantizedTensor` class found in `comfy/quant_ops.py`.

A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method

```python
from comfy.quant_ops import QuantizedLayout

class MyLayout(QuantizedLayout):
    @classmethod
    def quantize(cls, tensor, **kwargs):
        # Convert to quantized format
        qdata = ...
        params = {'scale': ..., 'orig_dtype': tensor.dtype}
        return qdata, params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale
```

To run operations on these QuantizedTensors, we use two registry systems that define the supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).

The second registry is layout-specific and makes it possible to implement fast paths such as `nn.Linear`.

```python
from comfy.quant_ops import register_layout_op

@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
    # Extract the quantized tensors from args and call an optimized kernel
    ...
```

When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor falls back to calling `dequantize` and dispatching to the high-precision implementation.
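
As an illustration of this dispatch-and-fallback idea (not ComfyUI's actual implementation), a toy tensor subclass might wire a handler registry and a dequantize fallback through `__torch_dispatch__` like this; every name below is made up for the sketch:

```python
import torch

_HANDLERS = {}  # maps an aten op to a layout-specific fast path

def register_op(op):
    def deco(fn):
        _HANDLERS[op] = fn
        return fn
    return deco

class ToyQuantTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, qdata, scale, orig_dtype):
        t = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=orig_dtype, device=qdata.device)
        t._qdata, t._scale = qdata, scale
        return t

    def dequantize(self):
        return self._qdata.to(self.dtype) * self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _HANDLERS:          # registered fast path
            return _HANDLERS[func](func, args, kwargs)
        # Fallback: dequantize all quantized arguments and re-dispatch in high precision
        args = tuple(a.dequantize() if isinstance(a, ToyQuantTensor) else a for a in args)
        return func(*args, **kwargs)

@register_op(torch.ops.aten.mm.default)
def toy_mm(func, args, kwargs):
    a, b = args
    a_hp = a.dequantize() if isinstance(a, ToyQuantTensor) else a
    b_hp = b.dequantize() if isinstance(b, ToyQuantTensor) else b
    return torch.mm(a_hp, b_hp)        # a real kernel would consume the FP8 data directly

qt = ToyQuantTensor((torch.randn(8, 8) / 0.05).to(torch.float8_e4m3fn), torch.tensor(0.05), torch.float32)
y = torch.mm(qt, torch.randn(8, 8))    # routed through toy_mm
z = qt + 1.0                            # no handler registered -> dequantize fallback
```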

### Mixed Precision

The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. It is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.

**Architecture:**

```python
class MixedPrecisionOps(disable_weight_init):
    _layer_quant_config = {}            # Maps layer names to quantization configs
    _compute_dtype = torch.bfloat16     # Default compute / dequantize precision
```

**Key mechanism:**

The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load the weight as a regular tensor in `_compute_dtype`.
- If the layer name **is** in `_layer_quant_config`:
  - Load the weight as a `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`).
  - Load the associated quantization parameters (scales, block_size, etc.).

**Why it's needed:**

Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.

The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
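
For illustration only, a `layer_quant_config` might conceptually look like the dictionary below; the exact schema ComfyUI uses is not spelled out in this document, so treat the keys as assumptions:

```python
# Hypothetical per-layer configuration of the kind that gates MixedPrecisionOps.
# The real mapping is derived from the checkpoint's _quantization_metadata
# (see the Checkpoint Format section below).
layer_quant_config = {
    "model.layers.0.mlp.up_proj":   {"format": "float8_e4m3fn"},
    "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
    # Sensitive layers (e.g. final projections) are simply left out and are
    # therefore loaded as regular tensors in _compute_dtype (bfloat16 by default).
}
```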

## Checkpoint Format

Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.

The quantized checkpoint contains the same layers as the original checkpoint, but:
- The weights are stored as quantized values, sometimes using a different storage datatype (e.g. a uint8 container for fp8).
- For each quantized weight, a number of additional scaling parameters are stored alongside it, depending on the recipe.
- The safetensors metadata of the final file carries a `_quantization_metadata` entry describing which layers are quantized and which layout has been used (a minimal writing sketch follows below).
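
A minimal sketch of writing such a checkpoint with the `safetensors` library (the layer names and scale values are placeholders, and the exact encoding of the metadata key is an assumption; only the `_quantization_metadata` content follows this document):

```python
import json
import torch
from safetensors.torch import save_file

# Placeholder tensors standing in for one real quantized layer
w_q = torch.randn(64, 64).to(torch.float8_e4m3fn)        # quantized weight (FP8 storage)
weight_scale = torch.tensor(0.01, dtype=torch.float32)    # per-tensor weight scale
input_scale = torch.tensor(0.02, dtype=torch.float32)     # activation scale (if calibrated)

state_dict = {
    "model.layers.0.mlp.up_proj.weight": w_q,
    "model.layers.0.mlp.up_proj.weight_scale": weight_scale,
    "model.layers.0.mlp.up_proj.input_scale": input_scale,
}

# safetensors metadata must be str -> str, hence the json.dumps
metadata = {
    "_quantization_metadata": json.dumps({
        "format_version": "1.0",
        "layers": {"model.layers.0.mlp.up_proj": "float8_e4m3fn"},
    })
}

save_file(state_dict, "model_fp8.safetensors", metadata=metadata)
```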

### Scaling Parameters details

We define 4 possible scaling parameters that should cover most recipes in the near future:
- **weight_scale**: quantization scales for the weights
- **weight_scale_2**: global scales in the context of double scaling
- **pre_quant_scale**: scales used for smoothing salient weights
- **input_scale**: quantization scales for the activations

| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |

You can find the defined formats in `comfy/quant_ops.py` (`QUANT_ALGOS`).

### Quantization Metadata

The metadata stored alongside the checkpoint contains:
- **format_version**: a string defining the version of the standard
- **layers**: a dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.

Example:
```json
{
  "_quantization_metadata": {
    "format_version": "1.0",
    "layers": {
      "model.layers.0.mlp.up_proj": "float8_e4m3fn",
      "model.layers.0.mlp.down_proj": "float8_e4m3fn",
      "model.layers.1.mlp.up_proj": "float8_e4m3fn"
    }
  }
}
```
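
Reading the metadata back is symmetric; a small sketch with `safetensors.safe_open` (file name and metadata key as assumed in the writing sketch above):

```python
import json
from safetensors import safe_open

with safe_open("model_fp8.safetensors", framework="pt") as f:
    qmeta = json.loads((f.metadata() or {})["_quantization_metadata"])

for layer, fmt in qmeta["layers"].items():
    print(layer, "->", fmt)   # e.g. model.layers.0.mlp.up_proj -> float8_e4m3fn
```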

## Creating Quantized Checkpoints

To create compatible checkpoints, you can use any quantization tool, provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.

### Weight Quantization

Weight quantization is straightforward: compute the scaling factor directly from the weight tensor using the absolute-maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
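
A hedged sketch of that per-layer loop for the `float8_e4m3fn` recipe, operating on a plain state dict (the layer selection and key naming are illustrative assumptions):

```python
import torch

def quantize_weights_fp8(state_dict, layers_to_quantize):
    """Per-tensor absmax FP8 weight quantization; returns new tensors plus metadata."""
    out, layers_meta = {}, {}
    f8_max = torch.finfo(torch.float8_e4m3fn).max
    for name, w in state_dict.items():
        layer = name.rsplit(".weight", 1)[0]
        if name.endswith(".weight") and layer in layers_to_quantize:
            scale = w.abs().max().to(torch.float32) / f8_max
            out[name] = (w.to(torch.float32) / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
            out[f"{layer}.weight_scale"] = scale
            layers_meta[layer] = "float8_e4m3fn"
        else:
            out[name] = w
    return out, {"format_version": "1.0", "layers": layers_meta}
```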

### Calibration (for Activation Quantization)

Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from the static weights alone. Since activation values depend on the actual inputs, we use **post-training quantization (PTQ) calibration**:

1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of the inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from the collected statistics
4. **Store in checkpoint**: Save the `input_scale` parameters alongside the weights

The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
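
A minimal sketch of such a calibration pass using forward pre-hooks (the module selection, sample loader, and scale formula are illustrative assumptions):

```python
import torch

def calibrate_input_scales(model, sample_batches, layers_to_quantize):
    """Track per-layer activation absmax over calibration data and derive input_scale."""
    amax, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs):
            cur = inputs[0].detach().abs().max().to(torch.float32)
            amax[name] = torch.maximum(amax.get(name, cur), cur)
        return hook

    for name, module in model.named_modules():
        if name in layers_to_quantize:
            hooks.append(module.register_forward_pre_hook(make_hook(name)))

    with torch.no_grad():
        for batch in sample_batches:   # 1. collect statistics on representative samples
            model(batch)               # 2. hooks record the activation absmax

    for h in hooks:
        h.remove()

    f8_max = torch.finfo(torch.float8_e4m3fn).max
    # 3. compute scales; 4. these land in the checkpoint as "<layer>.input_scale"
    return {name: (a / f8_max) for name, a in amax.items()}
```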
@@ -173,7 +173,7 @@ There is a portable standalone build for Windows that should work for running on

 ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)

-Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
+Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\

 If you have trouble extracting it, right click the file -> properties -> unblock

@@ -183,7 +183,9 @@ Update your Nvidia drivers if it doesn't start.

 [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)

-[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
+[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
+
+[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

 #### How do I share models between another UI and ComfyUI?

@@ -221,7 +223,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins

 This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```

 ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
@@ -1,15 +1,15 @@
 import torch
 from torch import Tensor, nn

-from comfy.ldm.flux.math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,
-    QKNorm,
-    SelfAttention,
     ModulationOut,
 )

+# TODO: remove this in a few months
+SingleStreamBlock = None
+DoubleStreamBlock = None


 class ChromaModulationOut(ModulationOut):
@@ -48,124 +48,6 @@ class Approximator(nn.Module):
         return x


-class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
-        super().__init__()
-
-        mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-
-        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-        self.flipped_img_txt = flipped_img_txt
-
-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
-        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
-
-        # prepare image for attention
-        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
-        img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-        # prepare txt for attention
-        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
-        txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask, transformer_options=transformer_options)
-
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
-        # calculate the img bloks
-        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
-        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
-
-        # calculate the txt bloks
-        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
-        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
-
-        if txt.dtype == torch.float16:
-            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
-
-        return img, txt
-
-
-class SingleStreamBlock(nn.Module):
-    """
-    A DiT block with parallel linear layers as described in
-    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qk_scale: float = None,
-        dtype=None,
-        device=None,
-        operations=None
-    ):
-        super().__init__()
-        self.hidden_dim = hidden_size
-        self.num_heads = num_heads
-        head_dim = hidden_size // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
-        # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
-
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-
-        self.hidden_size = hidden_size
-        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-
-        self.mlp_act = nn.GELU(approximate="tanh")
-
-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
-        mod = vec
-        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-
-        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k = self.norm(q, k, v)
-
-        # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
-        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x.addcmul_(mod.gate, output)
-        if x.dtype == torch.float16:
-            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
-        return x
-
-
 class LastLayer(nn.Module):
     def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()
@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
 from comfy.ldm.flux.layers import (
     EmbedND,
     timestep_embedding,
+    DoubleStreamBlock,
+    SingleStreamBlock,
 )

 from .layers import (
-    DoubleStreamBlock,
     LastLayer,
-    SingleStreamBlock,
     Approximator,
     ChromaModulationOut,
 )
@@ -90,6 +90,7 @@ class Chroma(nn.Module):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -98,7 +99,7 @@ class Chroma(nn.Module):

         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )
@@ -10,12 +10,10 @@ from torch import Tensor, nn
 from einops import repeat
 import comfy.ldm.common_dit

-from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock

 from comfy.ldm.chroma.model import Chroma, ChromaParams
 from comfy.ldm.chroma.layers import (
-    DoubleStreamBlock,
-    SingleStreamBlock,
     Approximator,
 )
 from .layers import (
@@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
             dtype=dtype, device=device, operations=operations
         )
-
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
@@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations,
                 )
                 for _ in range(params.depth_single_blocks)
@@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):


 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
         self.hidden_size = hidden_size
-        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        self.modulation = modulation
+
+        if self.modulation:
+            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+
         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -147,7 +151,9 @@ class DoubleStreamBlock(nn.Module):
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )

-        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        if self.modulation:
+            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+
         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -160,46 +166,65 @@ class DoubleStreamBlock(nn.Module):
         self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
-        img_mod1, img_mod2 = self.img_mod(vec)
-        txt_mod1, txt_mod2 = self.txt_mod(vec)
+        if self.modulation:
+            img_mod1, img_mod2 = self.img_mod(vec)
+            txt_mod1, txt_mod2 = self.txt_mod(vec)
+        else:
+            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

         # prepare image for attention
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
+        del img_modulated
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         if self.flipped_img_txt:
+            q = torch.cat((img_q, txt_q), dim=2)
+            del img_q, txt_q
+            k = torch.cat((img_k, txt_k), dim=2)
+            del img_k, txt_k
+            v = torch.cat((img_v, txt_v), dim=2)
+            del img_v, txt_v
             # run actual attention
-            attn = attention(torch.cat((img_q, txt_q), dim=2),
-                             torch.cat((img_k, txt_k), dim=2),
-                             torch.cat((img_v, txt_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
         else:
+            q = torch.cat((txt_q, img_q), dim=2)
+            del txt_q, img_q
+            k = torch.cat((txt_k, img_k), dim=2)
+            del txt_k, img_k
+            v = torch.cat((txt_v, img_v), dim=2)
+            del txt_v, img_v
             # run actual attention
-            attn = attention(torch.cat((txt_q, img_q), dim=2),
-                             torch.cat((txt_k, img_k), dim=2),
-                             torch.cat((txt_v, img_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        del img_attn
         img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
         txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        del txt_attn
         txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
@@ -220,6 +245,7 @@ class SingleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float = 4.0,
         qk_scale: float = None,
+        modulation=True,
         dtype=None,
         device=None,
         operations=None
@@ -242,19 +268,29 @@ class SingleStreamBlock(nn.Module):
         self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

         self.mlp_act = nn.GELU(approximate="tanh")
-        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        if modulation:
+            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        else:
+            self.modulation = None

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
-        mod, _ = self.modulation(vec)
+        if self.modulation:
+            mod, _ = self.modulation(vec)
+        else:
+            mod = vec

         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del qkv
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        mlp = self.mlp_act(mlp)
+        output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
@@ -7,7 +7,8 @@ import comfy.model_management


 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
-    q, k = apply_rope(q, k, pe)
+    if pe is not None:
+        q, k = apply_rope(q, k, pe)
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
     return x
37  comfy/ops.py
@@ -77,7 +77,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
-            dtype = input.dtype
+            if isinstance(input, QuantizedTensor):
+                dtype = input._layout_params["orig_dtype"]
+            else:
+                dtype = input.dtype
         if bias_dtype is None:
             bias_dtype = dtype
         if device is None:
@@ -534,18 +537,7 @@ if CUBLAS_IS_AVAILABLE:
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor
+from .quant_ops import QuantizedTensor, QUANT_ALGOS
-
-QUANT_FORMAT_MIXINS = {
-    "float8_e4m3fn": {
-        "dtype": torch.float8_e4m3fn,
-        "layout_type": "TensorCoreFP8Layout",
-        "parameters": {
-            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
-            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
-        }
-    }
-}
-
 class MixedPrecisionOps(disable_weight_init):
     _layer_quant_config = {}
@@ -596,23 +588,24 @@ class MixedPrecisionOps(disable_weight_init):
             if quant_format is None:
                 raise ValueError(f"Unknown quantization format for layer {layer_name}")

-            mixin = QUANT_FORMAT_MIXINS[quant_format]
-            self.layout_type = mixin["layout_type"]
+            qconfig = QUANT_ALGOS[quant_format]
+            self.layout_type = qconfig["comfy_tensor_layout"]

-            scale_key = f"{prefix}weight_scale"
+            weight_scale_key = f"{prefix}weight_scale"
             layout_params = {
-                'scale': state_dict.pop(scale_key, None),
-                'orig_dtype': MixedPrecisionOps._compute_dtype
+                'scale': state_dict.pop(weight_scale_key, None),
+                'orig_dtype': MixedPrecisionOps._compute_dtype,
+                'block_size': qconfig.get("group_size", None),
             }
             if layout_params['scale'] is not None:
-                manually_loaded_keys.append(scale_key)
+                manually_loaded_keys.append(weight_scale_key)

             self.weight = torch.nn.Parameter(
-                QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
                 requires_grad=False
             )

-            for param_name, param_value in mixin["parameters"].items():
+            for param_name in qconfig["parameters"]:
                 param_key = f"{prefix}{param_name}"
                 _v = state_dict.pop(param_key, None)
                 if _v is None:
@@ -643,7 +636,7 @@ class MixedPrecisionOps(disable_weight_init):
         if (getattr(self, 'layout_type', None) is not None and
             getattr(self, 'input_scale', None) is not None and
             not isinstance(input, QuantizedTensor)):
-            input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
+            input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
         return self._forward(input, self.weight, self.bias)
@@ -74,6 +74,12 @@ def _copy_layout_params(params):
             new_params[k] = v
     return new_params

+def _copy_layout_params_inplace(src, dst, non_blocking=False):
+    for k, v in src.items():
+        if isinstance(v, torch.Tensor):
+            dst[k].copy_(v, non_blocking=non_blocking)
+        else:
+            dst[k] = v
+
 class QuantizedLayout:
     """
@@ -318,13 +324,13 @@ def generic_to_dtype_layout(func, args, kwargs):
 def generic_copy_(func, args, kwargs):
     qt_dest = args[0]
     src = args[1]
+    non_blocking = args[2] if len(args) > 2 else False
     if isinstance(qt_dest, QuantizedTensor):
         if isinstance(src, QuantizedTensor):
             # Copy from another quantized tensor
-            qt_dest._qdata.copy_(src._qdata)
+            qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
             qt_dest._layout_type = src._layout_type
-            qt_dest._layout_params = _copy_layout_params(src._layout_params)
+            _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
         else:
             # Copy from regular tensor - just copy raw data
             qt_dest._qdata.copy_(src)
@@ -336,6 +342,26 @@ def generic_copy_(func, args, kwargs):
 def generic_has_compatible_shallow_copy_type(func, args, kwargs):
     return True

+
+@register_generic_util(torch.ops.aten.empty_like.default)
+def generic_empty_like(func, args, kwargs):
+    """Empty_like operation - creates an empty tensor with the same quantized structure."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        # Create empty tensor with same shape and dtype as the quantized data
+        hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
+        new_qdata = torch.empty_like(qt._qdata, **kwargs)
+
+        # Handle device transfer for layout params
+        target_device = kwargs.get('device', new_qdata.device)
+        new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+
+        # Update orig_dtype if dtype is specified
+        new_params['orig_dtype'] = hp_dtype
+
+        return QuantizedTensor(new_qdata, qt._layout_type, new_params)
+    return func(*args, **kwargs)

 # ==============================================================================
 # FP8 Layout + Operation Handlers
 # ==============================================================================
@@ -378,6 +404,13 @@ class TensorCoreFP8Layout(QuantizedLayout):
     def get_plain_tensors(cls, qtensor):
         return qtensor._qdata, qtensor._layout_params['scale']

+QUANT_ALGOS = {
+    "float8_e4m3fn": {
+        "storage_t": torch.float8_e4m3fn,
+        "parameters": {"weight_scale", "input_scale"},
+        "comfy_tensor_layout": "TensorCoreFP8Layout",
+    },
+}
+
 LAYOUTS = {
     "TensorCoreFP8Layout": TensorCoreFP8Layout,
@@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -468,6 +468,7 @@ class SDTokenizer:
         self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding
+        self.pad_left = pad_left

         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -522,6 +523,12 @@ class SDTokenizer:
                 return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
         return (embed, leftover)

+    def pad_tokens(self, tokens, amount):
+        if self.pad_left:
+            for i in range(amount):
+                tokens.insert(0, (self.pad_token, 1.0, 0))
+        else:
+            tokens.extend([(self.pad_token, 1.0, 0)] * amount)

     def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
@@ -600,7 +607,7 @@ class SDTokenizer:
                     if self.end_token is not None:
                         batch.append((self.end_token, 1.0, 0))
                     if self.pad_to_max_length:
-                        batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
+                        self.pad_tokens(batch, remaining_length)
                     #start new batch
                     batch = []
                     if self.start_token is not None:
@@ -614,11 +621,11 @@ class SDTokenizer:
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
         if min_padding is not None:
-            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+            self.pad_tokens(batch, min_padding)
         if self.pad_to_max_length and len(batch) < self.max_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            self.pad_tokens(batch, self.max_length - len(batch))
         if min_length is not None and len(batch) < min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
+            self.pad_tokens(batch, min_length - len(batch))

         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@@ -32,6 +32,7 @@ class Llama2Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Qwen25_3BConfig:
@@ -53,6 +54,7 @@ class Qwen25_3BConfig:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Qwen25_7BVLI_Config:
@@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Gemma2_2B_Config:
@@ -96,6 +99,7 @@ class Gemma2_2B_Config:
     k_norm = None
     sliding_attention = None
     rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Gemma3_4B_Config:
@@ -118,6 +122,7 @@ class Gemma3_4B_Config:
     k_norm = "gemma3"
     sliding_attention = [False, False, False, False, False, 1024]
     rope_scale = [1.0, 8.0]
+    final_norm: bool = True

 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -366,7 +371,12 @@ class Llama2_(nn.Module):
             transformer(config, index=i, device=device, dtype=dtype, ops=ops)
             for i in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+        if config.final_norm:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        else:
+            self.norm = None
+
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@@ -421,14 +431,16 @@ class Llama2_(nn.Module):
             if i == intermediate_output:
                 intermediate = x.clone()

-        x = self.norm(x)
+        if self.norm is not None:
+            x = self.norm(x)

         if all_intermediate is not None:
             all_intermediate.append(x.unsqueeze(1).clone())

         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)

-        if intermediate is not None and final_layer_norm_intermediate:
+        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
             intermediate = self.norm(intermediate)

         return x, intermediate
@@ -1,22 +1,229 @@
-from typing import Optional
+from datetime import date
+from enum import Enum
+from typing import Any

-from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

+
+class GeminiSafetyCategory(str, Enum):
+    HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+    HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
+    HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
+    HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
+
+
+class GeminiSafetyThreshold(str, Enum):
+    OFF = "OFF"
+    BLOCK_NONE = "BLOCK_NONE"
+    BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
+    BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
+    BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
+
+
+class GeminiSafetySetting(BaseModel):
+    category: GeminiSafetyCategory
+    threshold: GeminiSafetyThreshold
+
+
+class GeminiRole(str, Enum):
+    user = "user"
+    model = "model"
+
+
+class GeminiMimeType(str, Enum):
+    application_pdf = "application/pdf"
+    audio_mpeg = "audio/mpeg"
+    audio_mp3 = "audio/mp3"
+    audio_wav = "audio/wav"
+    image_png = "image/png"
+    image_jpeg = "image/jpeg"
+    image_webp = "image/webp"
+    text_plain = "text/plain"
+    video_mov = "video/mov"
+    video_mpeg = "video/mpeg"
+    video_mp4 = "video/mp4"
+    video_mpg = "video/mpg"
+    video_avi = "video/avi"
+    video_wmv = "video/wmv"
+    video_mpegps = "video/mpegps"
+    video_flv = "video/flv"
+
+
+class GeminiInlineData(BaseModel):
+    data: str | None = Field(
+        None,
+        description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
+        "When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
+    )
+    mimeType: GeminiMimeType | None = Field(None)
+
+
+class GeminiPart(BaseModel):
+    inlineData: GeminiInlineData | None = Field(None)
+    text: str | None = Field(None)
+
+
+class GeminiTextPart(BaseModel):
+    text: str | None = Field(None)
+
+
+class GeminiContent(BaseModel):
+    parts: list[GeminiPart] = Field(...)
+    role: GeminiRole = Field(..., examples=["user"])
+
+
+class GeminiSystemInstructionContent(BaseModel):
+    parts: list[GeminiTextPart] = Field(
+        ...,
+        description="A list of ordered parts that make up a single message. "
+        "Different parts may have different IANA MIME types.",
+    )
+    role: GeminiRole = Field(
+        ...,
+        description="The identity of the entity that creates the message. "
+        "The following values are supported: "
+        "user: This indicates that the message is sent by a real person, typically a user-generated message. "
+        "model: This indicates that the message is generated by the model. "
+        "The model value is used to insert messages from model into the conversation during multi-turn conversations. "
+        "For non-multi-turn conversations, this field can be left blank or unset.",
+    )
+
+
+class GeminiFunctionDeclaration(BaseModel):
+    description: str | None = Field(None)
+    name: str = Field(...)
+    parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
+
+
+class GeminiTool(BaseModel):
+    functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
+
+
+class GeminiOffset(BaseModel):
+    nanos: int | None = Field(None, ge=0, le=999999999)
+    seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
+
+
+class GeminiVideoMetadata(BaseModel):
+    endOffset: GeminiOffset | None = Field(None)
+    startOffset: GeminiOffset | None = Field(None)
+
+
+class GeminiGenerationConfig(BaseModel):
+    maxOutputTokens: int | None = Field(None, ge=16, le=8192)
+    seed: int | None = Field(None)
+    stopSequences: list[str] | None = Field(None)
+    temperature: float | None = Field(1, ge=0.0, le=2.0)
+    topK: int | None = Field(40, ge=1)
+    topP: float | None = Field(0.95, ge=0.0, le=1.0)
+
+
 class GeminiImageConfig(BaseModel):
-    aspectRatio: Optional[str] = None
+    aspectRatio: str | None = Field(None)
+    resolution: str | None = Field(None)


 class GeminiImageGenerationConfig(GeminiGenerationConfig):
-    responseModalities: Optional[list[str]] = None
-    imageConfig: Optional[GeminiImageConfig] = None
+    responseModalities: list[str] | None = Field(None)
+    imageConfig: GeminiImageConfig | None = Field(None)


 class GeminiImageGenerateContentRequest(BaseModel):
-    contents: list[GeminiContent]
-    generationConfig: Optional[GeminiImageGenerationConfig] = None
-    safetySettings: Optional[list[GeminiSafetySetting]] = None
-    systemInstruction: Optional[GeminiSystemInstructionContent] = None
-    tools: Optional[list[GeminiTool]] = None
-    videoMetadata: Optional[GeminiVideoMetadata] = None
+    contents: list[GeminiContent] = Field(...)
+    generationConfig: GeminiImageGenerationConfig | None = Field(None)
+    safetySettings: list[GeminiSafetySetting] | None = Field(None)
+    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
+    tools: list[GeminiTool] | None = Field(None)
+    videoMetadata: GeminiVideoMetadata | None = Field(None)
+
+
+class GeminiGenerateContentRequest(BaseModel):
+    contents: list[GeminiContent] = Field(...)
+    generationConfig: GeminiGenerationConfig | None = Field(None)
+    safetySettings: list[GeminiSafetySetting] | None = Field(None)
+    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
+    tools: list[GeminiTool] | None = Field(None)
+    videoMetadata: GeminiVideoMetadata | None = Field(None)
+
+
+class Modality(str, Enum):
+    MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
+    TEXT = "TEXT"
+    IMAGE = "IMAGE"
+    VIDEO = "VIDEO"
+    AUDIO = "AUDIO"
+    DOCUMENT = "DOCUMENT"
+
+
+class ModalityTokenCount(BaseModel):
+    modality: Modality | None = None
+    tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
+
+
+class Probability(str, Enum):
+    NEGLIGIBLE = "NEGLIGIBLE"
+    LOW = "LOW"
+    MEDIUM = "MEDIUM"
+    HIGH = "HIGH"
+    UNKNOWN = "UNKNOWN"
+
+
+class GeminiSafetyRating(BaseModel):
+    category: GeminiSafetyCategory | None = None
+    probability: Probability | None = Field(
+        None,
+        description="The probability that the content violates the specified safety category",
+    )
+
+
+class GeminiCitation(BaseModel):
+    authors: list[str] | None = None
+    endIndex: int | None = None
+    license: str | None = None
+    publicationDate: date | None = None
+    startIndex: int | None = None
+    title: str | None = None
+    uri: str | None = None
+
+
+class GeminiCitationMetadata(BaseModel):
+    citations: list[GeminiCitation] | None = None
+
+
+class GeminiCandidate(BaseModel):
+    citationMetadata: GeminiCitationMetadata | None = None
+    content: GeminiContent | None = None
+    finishReason: str | None = None
+    safetyRatings: list[GeminiSafetyRating] | None = None
+
+
+class GeminiPromptFeedback(BaseModel):
+    blockReason: str | None = None
+    blockReasonMessage: str | None = None
+    safetyRatings: list[GeminiSafetyRating] | None = None
+
+
+class GeminiUsageMetadata(BaseModel):
+    cachedContentTokenCount: int | None = Field(
+        None,
+        description="Output only. Number of tokens in the cached part in the input (the cached content).",
+    )
+    candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
+    candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
+        None, description="Breakdown of candidate tokens by modality."
+    )
+    promptTokenCount: int | None = Field(
+        None,
+        description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
+    )
+    promptTokensDetails: list[ModalityTokenCount] | None = Field(
+        None, description="Breakdown of prompt tokens by modality."
+    )
+    thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
+    toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
+
+
+class GeminiGenerateContentResponse(BaseModel):
+    candidates: list[GeminiCandidate] | None = Field(None)
+    promptFeedback: GeminiPromptFeedback | None = Field(None)
+    usageMetadata: GeminiUsageMetadata | None = Field(None)
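For orientation, the request and response models above mirror the Vertex AI generateContent payload. Below is a minimal sketch, not part of the commit, of how they compose into a text-only request; it assumes a ComfyUI checkout where comfy_api_nodes.apis.gemini_api is importable, and the prompt string and parameter values are made up for illustration.

# Illustrative sketch (not part of this diff): build a minimal text-only request.
from comfy_api_nodes.apis.gemini_api import (
    GeminiContent,
    GeminiGenerateContentRequest,
    GeminiGenerationConfig,
    GeminiPart,
    GeminiRole,
    GeminiSafetyCategory,
    GeminiSafetySetting,
    GeminiSafetyThreshold,
)

request = GeminiGenerateContentRequest(
    contents=[
        GeminiContent(
            role=GeminiRole.user,
            parts=[GeminiPart(text="Summarize this workflow in one sentence.")],
        )
    ],
    generationConfig=GeminiGenerationConfig(maxOutputTokens=256, temperature=0.7),
    safetySettings=[
        GeminiSafetySetting(
            category=GeminiSafetyCategory.HARM_CATEGORY_HARASSMENT,
            threshold=GeminiSafetyThreshold.BLOCK_ONLY_HIGH,
        )
    ],
)
# With pydantic v2, request.model_dump(exclude_none=True) gives the dict that is
# serialized into the JSON request body.

Inline media would go through GeminiInlineData inside a GeminiPart instead of the text field, with mimeType set to one of the GeminiMimeType values and the payload base64-encoded (20MB size limit, per the field description).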
@@ -3,8 +3,6 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API
 See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
 """

-from __future__ import annotations
-
 import base64
 import json
 import os
@@ -12,7 +10,7 @@ import time
 import uuid
 from enum import Enum
 from io import BytesIO
-from typing import Literal, Optional
+from typing import Literal

 import torch
 from typing_extensions import override
@@ -20,18 +18,17 @@ from typing_extensions import override
 import folder_paths
 from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api.util import VideoCodec, VideoContainer
-from comfy_api_nodes.apis import (
+from comfy_api_nodes.apis.gemini_api import (
     GeminiContent,
     GeminiGenerateContentRequest,
     GeminiGenerateContentResponse,
-    GeminiInlineData,
-    GeminiMimeType,
-    GeminiPart,
-)
-from comfy_api_nodes.apis.gemini_api import (
     GeminiImageConfig,
     GeminiImageGenerateContentRequest,
     GeminiImageGenerationConfig,
+    GeminiInlineData,
+    GeminiMimeType,
+    GeminiPart,
+    GeminiRole,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
@@ -57,6 +54,7 @@ class GeminiModel(str, Enum):
     gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
     gemini_2_5_pro = "gemini-2.5-pro"
     gemini_2_5_flash = "gemini-2.5-flash"
+    gemini_3_0_pro = "gemini-3-pro-preview"


 class GeminiImageModel(str, Enum):
@@ -103,6 +101,16 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
     Returns:
         List of response parts matching the requested type.
     """
+    if response.candidates is None:
+        if response.promptFeedback.blockReason:
+            feedback = response.promptFeedback
+            raise ValueError(
+                f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
+            )
+        raise NotImplementedError(
+            "Gemini returned no response candidates. "
+            "Please report to ComfyUI repository with the example of workflow to reproduce this."
+        )
     parts = []
     for part in response.candidates[0].content.parts:
         if part_type == "text" and hasattr(part, "text") and part.text:
@@ -272,10 +280,10 @@ class GeminiNode(IO.ComfyNode):
         prompt: str,
         model: str,
         seed: int,
-        images: Optional[torch.Tensor] = None,
-        audio: Optional[Input.Audio] = None,
-        video: Optional[Input.Video] = None,
-        files: Optional[list[GeminiPart]] = None,
+        images: torch.Tensor | None = None,
+        audio: Input.Audio | None = None,
+        video: Input.Video | None = None,
+        files: list[GeminiPart] | None = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)

@@ -300,7 +308,7 @@ class GeminiNode(IO.ComfyNode):
             data=GeminiGenerateContentRequest(
                 contents=[
                     GeminiContent(
-                        role="user",
+                        role=GeminiRole.user,
                         parts=parts,
                     )
                 ]
@@ -308,7 +316,6 @@ class GeminiNode(IO.ComfyNode):
             response_model=GeminiGenerateContentResponse,
         )

-        # Get result output
         output_text = get_text_from_response(response)
         if output_text:
             # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
@@ -406,7 +413,7 @@ class GeminiInputFiles(IO.ComfyNode):
         )

     @classmethod
-    def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
+    def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput:
         """Loads and formats input files for Gemini API."""
         if GEMINI_INPUT_FILES is None:
             GEMINI_INPUT_FILES = []
@@ -421,7 +428,7 @@ class GeminiImage(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="GeminiImageNode",
-            display_name="Google Gemini Image",
+            display_name="Nano Banana (Google Gemini Image)",
             category="api node/image/Gemini",
             description="Edit images synchronously via Google API.",
             inputs=[
@@ -488,8 +495,8 @@ class GeminiImage(IO.ComfyNode):
         prompt: str,
         model: str,
         seed: int,
-        images: Optional[torch.Tensor] = None,
-        files: Optional[list[GeminiPart]] = None,
+        images: torch.Tensor | None = None,
+        files: list[GeminiPart] | None = None,
         aspect_ratio: str = "auto",
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=True, min_length=1)
@@ -510,7 +517,7 @@ class GeminiImage(IO.ComfyNode):
             endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
             data=GeminiImageGenerateContentRequest(
                 contents=[
-                    GeminiContent(role="user", parts=parts),
+                    GeminiContent(role=GeminiRole.user, parts=parts),
                 ],
                 generationConfig=GeminiImageGenerationConfig(
                     responseModalities=["TEXT", "IMAGE"],
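The new guard at the top of get_parts_by_type surfaces blocked prompts instead of failing on response.candidates[0]. Below is a hedged sketch of the same check seen from a caller's side, using a hypothetical ensure_candidates helper (not part of the codebase) and assuming response is a parsed GeminiGenerateContentResponse.

# Illustrative sketch (not part of this diff): hypothetical helper mirroring the guard above.
def ensure_candidates(response):
    if response.candidates is None:
        feedback = response.promptFeedback
        if feedback is not None and feedback.blockReason:
            raise ValueError(
                f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
            )
        raise NotImplementedError("Gemini returned no response candidates.")
    return response.candidates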
@@ -11,13 +11,13 @@ if TYPE_CHECKING:

 def easycache_forward_wrapper(executor, *args, **kwargs):
     # get values from args
-    x: torch.Tensor = args[0]
     transformer_options: dict[str] = args[-1]
     if not isinstance(transformer_options, dict):
         transformer_options = kwargs.get("transformer_options")
         if not transformer_options:
             transformer_options = args[-2]
     easycache: EasyCacheHolder = transformer_options["easycache"]
+    x: torch.Tensor = args[0][:, :easycache.output_channels]
     sigmas = transformer_options["sigmas"]
     uuids = transformer_options["uuids"]
     if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):

 def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     # get values from args
-    x: torch.Tensor = args[0]
     timestep: float = args[1]
     model_options: dict[str] = args[2]
     easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
     if easycache.is_past_end_timestep(timestep):
         return executor(*args, **kwargs)
     # prepare next x_prev
+    x: torch.Tensor = args[0][:, :easycache.output_channels]
     next_x_prev = x
     input_change = None
     do_easycache = easycache.should_do_easycache(timestep)
@@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):


 class EasyCacheHolder:
-    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
+    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
         self.name = "EasyCache"
         self.reuse_threshold = reuse_threshold
         self.start_percent = start_percent
@@ -202,6 +202,7 @@ class EasyCacheHolder:
         self.allow_mismatch = True
         self.cut_from_start = True
         self.state_metadata = None
+        self.output_channels = output_channels

     def is_past_end_timestep(self, timestep: float) -> bool:
         return not (timestep[0] > self.end_t).item()
@@ -264,7 +265,7 @@ class EasyCacheHolder:
                 else:
                     slicing.append(slice(None))
             batch_slice = batch_slice + slicing
-            x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
+            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
         return x

     def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
@@ -283,7 +284,7 @@ class EasyCacheHolder:
                 else:
                     slicing.append(slice(None))
                 skip_dim = False
-            x = x[slicing]
+            x = x[tuple(slicing)]
         diff = output - x
         batch_offset = diff.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
@@ -323,7 +324,7 @@ class EasyCacheHolder:
         return self

     def clone(self):
-        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
+        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)


 class EasyCacheNode(io.ComfyNode):
@@ -350,7 +351,7 @@ class EasyCacheNode(io.ComfyNode):
     @classmethod
     def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
         model = model.clone()
-        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
+        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
@@ -358,7 +359,7 @@ class EasyCacheNode(io.ComfyNode):


 class LazyCacheHolder:
-    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
+    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
         self.name = "LazyCache"
         self.reuse_threshold = reuse_threshold
         self.start_percent = start_percent
@@ -382,6 +383,7 @@ class LazyCacheHolder:
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
         self.state_metadata = None
+        self.output_channels = output_channels

     def has_cache_diff(self) -> bool:
         return self.cache_diff is not None
@@ -456,7 +458,7 @@ class LazyCacheHolder:
         return self

     def clone(self):
-        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
+        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)

 class LazyCacheNode(io.ComfyNode):
     @classmethod
@@ -482,7 +484,7 @@ class LazyCacheNode(io.ComfyNode):
     @classmethod
     def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
         model = model.clone()
-        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
+        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
         return io.NodeOutput(model)
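The thread running through these hunks is the new output_channels value, set from model.model.latent_format.latent_channels in both nodes, which lets the wrappers slice the model input down to the latent channel count before caching and comparing it. A small illustrative tensor sketch, with made-up shapes, of what x = args[0][:, :easycache.output_channels] keeps; the rationale (inputs that carry extra concatenated channels beyond the sampled latent) is an inference from the change, not stated in the commit.

# Illustrative sketch (not part of this diff): hypothetical shapes only.
import torch

output_channels = 16                     # e.g. latent_format.latent_channels
x_full = torch.randn(2, 32, 64, 64)      # batch, channels (16 latent + 16 extra), height, width

x = x_full[:, :output_channels]          # what the cache now stores and compares
print(x.shape)                           # torch.Size([2, 16, 64, 64])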
39
comfy_extras/nodes_nop.py
Normal file
@@ -0,0 +1,39 @@
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+# If you write a node that is so useless that it breaks ComfyUI it will be featured in this exclusive list
+
+# "native" block swap nodes are placebo at best and break the ComfyUI memory management system.
+# They are also considered harmful because instead of users reporting issues with the built in
+# memory management they install these stupid nodes and complain even harder. Now it completely
+# breaks with some of the new ComfyUI memory optimizations so I have made the decision to NOP it
+# out of all workflows.
+class wanBlockSwap(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="wanBlockSwap",
+            category="",
+            description="NOP",
+            inputs=[
+                io.Model.Input("model"),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+            is_deprecated=True,
+        )
+
+    @classmethod
+    def execute(cls, model) -> io.NodeOutput:
+        return io.NodeOutput(model)
+
+
+class NopExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            wanBlockSwap
+        ]
+
+async def comfy_entrypoint() -> NopExtension:
+    return NopExtension()
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.68"
+__version__ = "0.3.70"
1
nodes.py
@@ -2331,6 +2331,7 @@ async def init_builtin_extra_nodes():
         "nodes_audio_encoder.py",
         "nodes_rope.py",
         "nodes_logic.py",
+        "nodes_nop.py",
     ]

     import_failed = []
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.68"
+version = "0.3.70"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -24,7 +24,7 @@ lint.select = [
 exclude = ["*.ipynb", "**/generated/*.pyi"]

 [tool.pylint]
-master.py-version = "3.9"
+master.py-version = "3.10"
 master.extension-pkg-allow-list = [
     "pydantic",
 ]
@@ -2,6 +2,7 @@ import os
 import sys
 import asyncio
 import traceback
+import time

 import nodes
 import folder_paths
@@ -733,6 +734,7 @@ class PromptServer():
                 for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
                     if sensitive_val in extra_data:
                         sensitive[sensitive_val] = extra_data.pop(sensitive_val)
+                extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
                 self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
                 response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
                 return web.json_response(response)