Compare commits

...

5 Commits

Author SHA1 Message Date
tashiscool
2fb6c7e0a3
Merge 45a2363e6a into c011fb520c 2026-05-07 22:10:16 +02:00
Alexander Piskun
c011fb520c
[Partner Nodes] new NanoBanana2 node with DynamicCombo/Autogrow (#13753)
* feat(api-nodes): new NanoBanana2 node with DynamicCombo/Autogrow

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat: improved status text on uploading

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat: improved status text on uploading (2)

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-07 12:19:44 -07:00
Alexander Piskun
c945a433ae
fix(api-nodes): fixed price badge for Kling V3 model in the Motion Control node (#13790)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-07 11:55:09 -07:00
Tashdid Khan
45a2363e6a
fix: remove hardcoded local paths from MPS FP8 tests
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 21:27:35 -05:00
Tashdid Khan
edd44a6874
fix: add CPU fallback for FP8 quantization on MPS (Apple Silicon)

MPS does not support float8_e4m3fn/float8_e5m2 dtypes. When FP8-quantized
models (FLUX, SD3.5, Wan 2.2, LTX-Video) are loaded on Apple Silicon, the
quantization step crashes with:

  TypeError: Trying to convert Float8_e4m3fn to the MPS backend but it does
  not have support for that dtype.

This adds device-aware fallbacks that move tensors to CPU for the FP8
quantization step only. The rest of inference remains on MPS.

Three code paths are patched:
- comfy/float.py: stochastic_rounding() — also fixes the secondary
  "Placeholder storage has not been allocated on MPS device!" error
  caused by torch.Generator being bound to MPS.
- comfy/float.py: stochastic_round_quantize_nvfp4*() — these create
  float8_e4m3fn block scales internally.
- comfy/quant_ops.py: _TensorCoreFP8LayoutBase.quantize() — the
  ck.quantize_per_tensor_fp8 path also fails on MPS.

Fixes: #6995, #9255, #11626, #11817

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 21:21:31 -05:00
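
A minimal usage sketch of the behavior the commit above introduces, mirroring the tests added below (assumes an Apple Silicon machine and a ComfyUI checkout with this commit applied):

import torch
import comfy.float

x = torch.randn(64, 64, device="mps")  # weights living on the Apple Silicon GPU
q = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)

assert q.dtype == torch.float8_e4m3fn
assert q.device.type == "cpu"  # MPS cannot hold FP8 tensors, so the result is materialized on CPU
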
5 changed files with 399 additions and 23 deletions

comfy/float.py

@@ -55,6 +55,11 @@ def stochastic_rounding(value, dtype, seed=0):
    if dtype == torch.bfloat16:
        return value.to(dtype=torch.bfloat16)
    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        # MPS does not support FP8 dtypes — perform rounding on CPU and return the result there.
        on_mps = value.device.type == "mps"
        if on_mps:
            value = value.cpu()
        generator = torch.Generator(device=value.device)
        generator.manual_seed(seed)
        output = torch.empty_like(value, dtype=dtype)
@@ -159,6 +164,12 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
        """Round up x to the nearest multiple."""
        return ((x + multiple - 1) // multiple) * multiple

    # MPS does not support FP8 dtypes used for block scales — perform on CPU.
    on_mps = x.device.type == "mps"
    if on_mps:
        x = x.cpu()
        per_tensor_scale = per_tensor_scale.cpu() if isinstance(per_tensor_scale, torch.Tensor) else per_tensor_scale
    generator = torch.Generator(device=x.device)
    generator.manual_seed(seed)
@@ -179,6 +190,12 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
        """Round up x to the nearest multiple."""
        return ((x + multiple - 1) // multiple) * multiple

    # MPS does not support FP8 dtypes used for block scales — perform on CPU.
    on_mps = x.device.type == "mps"
    if on_mps:
        x = x.cpu()
        per_tensor_scale = per_tensor_scale.cpu() if isinstance(per_tensor_scale, torch.Tensor) else per_tensor_scale

    orig_shape = x.shape
    # Handle padding
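
The NVFP4 path is exercised the same way from the caller's side; a minimal sketch mirroring TestNVFP4StochasticRoundMPS below (requires MPS and this commit applied):

import torch
import comfy.float

x = torch.randn(32, 32, device="mps")       # 2-D, dimensions divisible by 16
scale = torch.tensor(1.0, device="mps")
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(x, scale, pad_16x=False, seed=42)
assert qdata.dtype == torch.uint8            # packed 4-bit payload; the FP8 block scales are created on CPU internally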

comfy/quant_ops.py

@@ -93,6 +93,12 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)

        # MPS does not support FP8 dtypes — move to CPU for quantization.
        on_mps = tensor.device.type == "mps"
        if on_mps:
            tensor = tensor.cpu()
            scale = scale.cpu()

        if stochastic_rounding > 0:
            if inplace_ops:
                tensor *= (1.0 / scale).to(tensor.dtype)
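
At the quant-ops layer the call site is unchanged; a minimal sketch mirroring TestQuantOpsMPS below (requires MPS and this commit applied):

import torch
from comfy.quant_ops import TensorCoreFP8E4M3Layout

w = torch.randn(64, 64, device="mps", dtype=torch.bfloat16)
qdata, params = TensorCoreFP8E4M3Layout.quantize(w, scale="recalculate", stochastic_rounding=42)
assert qdata.dtype == torch.float8_e4m3fn
assert params.orig_dtype == torch.bfloat16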

View File

@@ -83,13 +83,16 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
    cls: type[IO.ComfyNode],
    images: Input.Image,
    images: Input.Image | list[Input.Image],
    image_limit: int = 0,
) -> list[GeminiPart]:
    image_parts: list[GeminiPart] = []
    if image_limit < 0:
        raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
    total_images = get_number_of_images(images)
    # Accept either a single (possibly-batched) tensor or a list of them; share URL budget across all.
    images_list: list[Input.Image] = images if isinstance(images, list) else [images]
    total_images = sum(get_number_of_images(img) for img in images_list)
    if total_images <= 0:
        raise ValueError("No images provided to create_image_parts; at least one image is required.")
@@ -98,10 +101,18 @@ async def create_image_parts(
    # Number of images we'll send as URLs (fileData)
    num_url_images = min(effective_max, 10)  # Vertex API max number of image links
    upload_kwargs: dict = {"wait_label": "Uploading reference images"}
    if effective_max > num_url_images:
        # Split path (e.g. 11+ images): suppress per-image counter to avoid a confusing dual-fraction label.
        upload_kwargs = {
            "wait_label": f"Uploading reference images ({num_url_images}+)",
            "show_batch_index": False,
        }
    reference_images_urls = await upload_images_to_comfyapi(
        cls,
        images,
        images_list,
        max_images=num_url_images,
        **upload_kwargs,
    )
    for reference_image_url in reference_images_urls:
        image_parts.append(
@@ -112,15 +123,22 @@ async def create_image_parts(
                )
            )
        )
    for idx in range(num_url_images, effective_max):
        image_parts.append(
            GeminiPart(
                inlineData=GeminiInlineData(
                    mimeType=GeminiMimeType.image_png,
                    data=tensor_to_base64_string(images[idx]),
    if effective_max > num_url_images:
        flat: list[torch.Tensor] = []
        for tensor in images_list:
            if len(tensor.shape) == 4:
                flat.extend(tensor[i] for i in range(tensor.shape[0]))
            else:
                flat.append(tensor)
        for idx in range(num_url_images, effective_max):
            image_parts.append(
                GeminiPart(
                    inlineData=GeminiInlineData(
                        mimeType=GeminiMimeType.image_png,
                        data=tensor_to_base64_string(flat[idx]),
                    )
                )
            )
    return image_parts
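
A short worked example of the split logic above, with hypothetical numbers not taken from the diff: batched 4-D tensors are flattened first, the first images go out as fileData URLs up to the Vertex cap, and the remainder are embedded inline as base64 PNG parts.

# Hypothetical walkthrough of the budget math in create_image_parts (illustrative only):
effective_max = 14                               # total reference images after applying image_limit
num_url_images = min(effective_max, 10)          # -> 10, the Vertex API cap on image links
inline_images = effective_max - num_url_images   # -> 4, sent as inlineData (base64 PNG) parts
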
@@ -891,10 +909,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
                        "9:16",
                        "16:9",
                        "21:9",
                        # "1:4",
                        # "4:1",
                        # "8:1",
                        # "1:8",
                    ],
                    default="auto",
                    tooltip="If set to 'auto', matches your input image's aspect ratio; "
@@ -902,12 +916,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
                ),
                IO.Combo.Input(
                    "resolution",
                    options=[
                        # "512px",
                        "1K",
                        "2K",
                        "4K",
                    ],
                    options=["1K", "2K", "4K"],
                    tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
                ),
                IO.Combo.Input(
@@ -956,6 +965,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
            ],
            is_api_node=True,
            price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
            is_deprecated=True,
        )

    @classmethod
@@ -1016,6 +1026,197 @@ class GeminiNanoBanana2(IO.ComfyNode):
        )


def _nano_banana_2_v2_model_inputs():
    return [
        IO.Combo.Input(
            "aspect_ratio",
            options=[
                "auto",
                "1:1",
                "2:3",
                "3:2",
                "3:4",
                "4:3",
                "4:5",
                "5:4",
                "9:16",
                "16:9",
                "21:9",
                "1:4",
                "4:1",
                "8:1",
                "1:8",
            ],
            default="auto",
            tooltip="If set to 'auto', matches your input image's aspect ratio; "
            "if no image is provided, a 16:9 square is usually generated.",
        ),
        IO.Combo.Input(
            "resolution",
            options=["1K", "2K", "4K"],
            tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
        ),
        IO.Combo.Input(
            "thinking_level",
            options=["MINIMAL", "HIGH"],
        ),
        IO.Autogrow.Input(
            "images",
            template=IO.Autogrow.TemplateNames(
                IO.Image.Input("image"),
                names=[f"image_{i}" for i in range(1, 15)],
                min=0,
            ),
            tooltip="Optional reference image(s). Up to 14 images total.",
        ),
        IO.Custom("GEMINI_INPUT_FILES").Input(
            "files",
            optional=True,
            tooltip="Optional file(s) to use as context for the model. "
            "Accepts inputs from the Gemini Generate Content Input Files node.",
        ),
    ]


class GeminiNanoBanana2V2(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="GeminiNanoBanana2V2",
            display_name="Nano Banana 2",
            category="api node/image/Gemini",
            description="Generate or edit images synchronously via Google Vertex API.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="Text prompt describing the image to generate or the edits to apply. "
                    "Include any constraints, styles, or details the model should follow.",
                    default="",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "Nano Banana 2 (Gemini 3.1 Flash Image)",
                            _nano_banana_2_v2_model_inputs(),
                        ),
                    ],
                ),
                IO.Int.Input(
                    "seed",
                    default=42,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    control_after_generate=True,
                    tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
                    "the same response for repeated requests. Deterministic output isn't guaranteed. "
                    "Also, changing the model or parameter settings, such as the temperature, "
                    "can cause variations in the response even when you use the same seed value. "
                    "By default, a random seed value is used.",
                ),
                IO.Combo.Input(
                    "response_modalities",
                    options=["IMAGE", "IMAGE+TEXT"],
                    advanced=True,
                ),
                IO.String.Input(
                    "system_prompt",
                    multiline=True,
                    default=GEMINI_IMAGE_SYS_PROMPT,
                    optional=True,
                    tooltip="Foundational instructions that dictate an AI's behavior.",
                    advanced=True,
                ),
            ],
            outputs=[
                IO.Image.Output(),
                IO.String.Output(),
                IO.Image.Output(
                    display_name="thought_image",
                    tooltip="First image from the model's thinking process. "
                    "Only available with thinking_level HIGH and IMAGE+TEXT modality.",
                ),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
                expr="""
                (
                    $r := $lookup(widgets, "model.resolution");
                    $prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
                    {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
        response_modalities: str,
        system_prompt: str = "",
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
        model_choice = model["model"]
        if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
            model_id = "gemini-3.1-flash-image-preview"
        else:
            model_id = model_choice
        images = model.get("images") or {}
        parts: list[GeminiPart] = [GeminiPart(text=prompt)]
        if images:
            image_tensors: list[Input.Image] = [t for t in images.values() if t is not None]
            if image_tensors:
                if sum(get_number_of_images(t) for t in image_tensors) > 14:
                    raise ValueError("The current maximum number of supported images is 14.")
                parts.extend(await create_image_parts(cls, image_tensors))
        files = model.get("files")
        if files is not None:
            parts.extend(files)
        image_config = GeminiImageConfig(imageSize=model["resolution"])
        if model["aspect_ratio"] != "auto":
            image_config.aspectRatio = model["aspect_ratio"]
        gemini_system_prompt = None
        if system_prompt:
            gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
        response = await sync_op(
            cls,
            ApiEndpoint(path=f"/proxy/vertexai/gemini/{model_id}", method="POST"),
            data=GeminiImageGenerateContentRequest(
                contents=[
                    GeminiContent(role=GeminiRole.user, parts=parts),
                ],
                generationConfig=GeminiImageGenerationConfig(
                    responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
                    imageConfig=image_config,
                    thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
                ),
                systemInstruction=gemini_system_prompt,
            ),
            response_model=GeminiGenerateContentResponse,
            price_extractor=calculate_tokens_price,
        )
        return IO.NodeOutput(
            await get_image_from_response(response),
            get_text_from_response(response),
            await get_image_from_response(response, thought=True),
        )


class GeminiExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1024,6 +1225,7 @@ class GeminiExtension(ComfyExtension):
            GeminiImage,
            GeminiImage2,
            GeminiNanoBanana2,
            GeminiNanoBanana2V2,
            GeminiInputFiles,
        ]
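
Aside: the new node's price_badge expression above maps the selected resolution to a per-image USD price. A hedged Python paraphrase of that JSONata-style lookup (illustrative only, not part of the diff; the table keys are lowercase, so this sketch assumes the widget value is normalized before the lookup):

# Illustrative paraphrase of the GeminiNanoBanana2V2 price badge.
NANO_BANANA_2_USD_PER_IMAGE = {"1k": 0.0696, "2k": 0.1014, "4k": 0.154}

def nano_banana_2_price(resolution: str) -> float | None:
    # "model.resolution" widget value ("1K" | "2K" | "4K") -> approximate $/Image shown on the badge.
    return NANO_BANANA_2_USD_PER_IMAGE.get(resolution.lower())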

View File

@@ -2787,11 +2787,15 @@ class MotionControl(IO.ComfyNode):
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
                depends_on=IO.PriceBadgeDepends(widgets=["mode", "model"]),
                expr="""
                (
                    $prices := {"std": 0.07, "pro": 0.112};
                    {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
                    $prices := {
                        "kling-v3": {"std": 0.126, "pro": 0.168},
                        "kling-v2-6": {"std": 0.07, "pro": 0.112}
                    };
                    $modelPrices := $lookup($prices, widgets.model);
                    {"type":"usd","usd": $lookup($modelPrices, widgets.mode), "format":{"suffix":"/second"}}
                )
                """,
            ),
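
A hedged Python paraphrase of the updated MotionControl price badge above (illustrative only, not part of the diff): the badge now keys on both the "model" and "mode" widgets instead of "mode" alone, so Kling V3 gets its own per-second rates.

KLING_MOTION_CONTROL_USD_PER_SECOND = {
    "kling-v3": {"std": 0.126, "pro": 0.168},
    "kling-v2-6": {"std": 0.07, "pro": 0.112},
}

def motion_control_price(model: str, mode: str) -> float | None:
    # Selected model and mode widget values -> $/second shown on the badge.
    return KLING_MOTION_CONTROL_USD_PER_SECOND.get(model, {}).get(mode)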

tests/test_fp8_mps.py (new file, 147 lines)

@@ -0,0 +1,147 @@
"""
Tests for FP8 quantization on MPS (Apple Silicon) devices.

MPS does not natively support float8_e4m3fn or float8_e5m2 dtypes.
These tests verify that:

1. FP8 operations correctly fall back to CPU when on MPS.
2. The round-trip (quantize on CPU -> result on original device) is numerically sound.
3. No "Placeholder storage has not been allocated on MPS device!" errors occur.
"""
import pytest
import torch

import comfy.float
from comfy.quant_ops import TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout

# Skip the entire module if MPS is not available
pytestmark = pytest.mark.skipif(
    not torch.backends.mps.is_available(),
    reason="MPS backend not available"
)


# ── helpers ──────────────────────────────────────────────────────────────────

def _make_mps_tensor(shape=(256, 256), dtype=torch.float32):
    return torch.randn(shape, device="mps", dtype=dtype)


# ── Tests for comfy.float ────────────────────────────────────────────────────

class TestStochasticRoundingMPS:
    """Tests for comfy.float.stochastic_rounding on MPS device."""

    def test_stochastic_rounding_fp8_e4m3fn_on_mps(self):
        """stochastic_rounding must not crash when input is on MPS and target dtype is float8_e4m3fn."""
        x = _make_mps_tensor((64, 64), dtype=torch.float32)
        result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
        assert result.dtype == torch.float8_e4m3fn
        assert result.shape == x.shape

    def test_stochastic_rounding_fp8_e5m2_on_mps(self):
        """stochastic_rounding must not crash when input is on MPS and target dtype is float8_e5m2."""
        x = _make_mps_tensor((64, 64), dtype=torch.float32)
        result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e5m2, seed=42)
        assert result.dtype == torch.float8_e5m2
        assert result.shape == x.shape

    def test_stochastic_rounding_fp8_result_on_cpu(self):
        """Result of FP8 rounding from MPS input should be on CPU (since MPS can't hold FP8)."""
        x = _make_mps_tensor((32, 32), dtype=torch.float32)
        result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
        # FP8 tensors cannot live on MPS, so result must be on CPU
        assert result.device.type == "cpu"

    def test_stochastic_rounding_non_fp8_still_works(self):
        """Non-FP8 dtypes on MPS must still work as before (no regression)."""
        x = _make_mps_tensor((32, 32), dtype=torch.float32)
        r16 = comfy.float.stochastic_rounding(x, dtype=torch.float16, seed=0)
        assert r16.dtype == torch.float16
        assert r16.device.type == "mps"
        rbf16 = comfy.float.stochastic_rounding(x, dtype=torch.bfloat16, seed=0)
        assert rbf16.dtype == torch.bfloat16
        assert rbf16.device.type == "mps"

    def test_stochastic_rounding_fp8_numerical_sanity(self):
        """FP8 round-trip (float32 -> fp8 -> float32) should have bounded error."""
        x = torch.randn(128, 128, device="mps", dtype=torch.float32)
        x_clamped = torch.clamp(x, min=-448, max=448)  # FP8 e4m3fn range
        fp8 = comfy.float.stochastic_rounding(x_clamped, dtype=torch.float8_e4m3fn, seed=123)
        # Convert back to float32 for comparison
        reconstructed = fp8.to(torch.float32)
        # Max relative error should be bounded (FP8 e4m3fn has ~0.125 relative precision)
        x_cpu = x_clamped.cpu()
        max_abs_err = (reconstructed - x_cpu).abs().max().item()
        # FP8 e4m3fn max value is 448, min subnormal ~0.001953
        # For random normal data, error should be well under 1.0
        assert max_abs_err < 2.0, f"FP8 round-trip error too large: {max_abs_err}"


class TestManualStochasticRoundMPS:
    """Tests for comfy.float.manual_stochastic_round_to_float8 on MPS device."""

    def test_manual_round_fp8_on_mps_tensor(self):
        """stochastic_rounding handles MPS generator internally without 'Placeholder storage' error."""
        x = _make_mps_tensor((16, 16), dtype=torch.float32)
        result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
        assert result.dtype == torch.float8_e4m3fn


class TestNVFP4StochasticRoundMPS:
    """Tests for NVFP4 stochastic rounding on MPS - also creates FP8 tensors internally."""

    def test_nvfp4_stochastic_round_on_mps(self):
        """stochastic_round_quantize_nvfp4 creates FP8 block scales internally."""
        # NVFP4 requires 2D input with dimensions divisible by 16
        x = torch.randn(32, 32, device="mps", dtype=torch.float32)
        scale = torch.tensor(1.0, device="mps", dtype=torch.float32)
        # This should not crash - internally creates float8_e4m3fn block scales
        qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(
            x, scale, pad_16x=False, seed=42
        )
        assert qdata.dtype == torch.uint8


# ── Tests for comfy.quant_ops (integration) ──────────────────────────────────

class TestQuantOpsMPS:
    """Tests for the quantization ops layer that calls into comfy.float."""

    def test_fp8_layout_quantize_on_mps(self):
        """TensorCoreFP8E4M3Layout.quantize must work with MPS tensors."""
        x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
        qdata, params = TensorCoreFP8E4M3Layout.quantize(
            x, scale="recalculate", stochastic_rounding=42
        )
        assert qdata.dtype == torch.float8_e4m3fn
        assert params.orig_dtype == torch.bfloat16

    def test_fp8_layout_quantize_without_stochastic_on_mps(self):
        """TensorCoreFP8E4M3Layout.quantize with stochastic_rounding=0 uses ck.quantize_per_tensor_fp8."""
        x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
        qdata, params = TensorCoreFP8E4M3Layout.quantize(
            x, scale="recalculate", stochastic_rounding=0
        )
        assert qdata.dtype == torch.float8_e4m3fn

    def test_fp8_e5m2_layout_quantize_on_mps(self):
        """TensorCoreFP8E5M2Layout.quantize must work with MPS tensors."""
        x = _make_mps_tensor((64, 64), dtype=torch.float32)
        qdata, params = TensorCoreFP8E5M2Layout.quantize(
            x, scale="recalculate", stochastic_rounding=42
        )
        assert qdata.dtype == torch.float8_e5m2


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])