Merge branch 'master' into enable-triton-comfy-kitchen

This commit is contained in:
Silver 2026-03-07 05:46:58 +01:00 committed by GitHub
commit 433e9a2365
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 588 additions and 71 deletions

View File

@ -27,6 +27,7 @@ class AudioEncoderModel():
self.model.eval() self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000 self.model_sample_rate = 16000
comfy.model_management.archive_model_dtypes(self.model)
def load_sd(self, sd): def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

View File

@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn as nn import torch.nn as nn
import comfy.ops import comfy.ops
import comfy.model_management
import numpy as np import numpy as np
import math import math
@ -81,7 +82,7 @@ class LowPassFilter1d(nn.Module):
_, C, _ = x.shape _, C, _ = x.shape
if self.padding: if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
class UpSample1d(nn.Module): class UpSample1d(nn.Module):
@ -125,7 +126,7 @@ class UpSample1d(nn.Module):
_, C, _ = x.shape _, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate") x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d( x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C
) )
x = x[..., self.pad_left : -self.pad_right] x = x[..., self.pad_left : -self.pad_right]
return x return x
@ -190,7 +191,7 @@ class Snake(nn.Module):
self.eps = 1e-9 self.eps = 1e-9
def forward(self, x): def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1) a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale: if self.alpha_logscale:
a = torch.exp(a) a = torch.exp(a)
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
@ -217,8 +218,8 @@ class SnakeBeta(nn.Module):
self.eps = 1e-9 self.eps = 1e-9
def forward(self, x): def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1) a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
b = self.beta.unsqueeze(0).unsqueeze(-1) b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale: if self.alpha_logscale:
a = torch.exp(a) a = torch.exp(a)
b = torch.exp(b) b = torch.exp(b)
@ -596,7 +597,7 @@ class _STFTFn(nn.Module):
y = y.unsqueeze(1) # (B, 1, T) y = y.unsqueeze(1) # (B, 1, T)
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
y = F.pad(y, (left_pad, 0)) y = F.pad(y, (left_pad, 0))
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2 n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:] real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real ** 2 + imag ** 2) magnitude = torch.sqrt(real ** 2 + imag ** 2)
@ -647,7 +648,7 @@ class MelSTFT(nn.Module):
""" """
magnitude, phase = self.stft_fn(y) magnitude, phase = self.stft_fn(y)
energy = torch.norm(magnitude, dim=1) energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5)) log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy return log_mel, magnitude, phase, energy

View File

@ -939,7 +939,7 @@ def text_encoder_offload_device():
def text_encoder_device(): def text_encoder_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False): if should_use_fp16(prioritize_performance=False):
return get_torch_device() return get_torch_device()
else: else:

View File

@ -715,8 +715,8 @@ class ModelPatcher:
default = True # default random weights in non leaf modules default = True # default random weights in non leaf modules
break break
if default and default_device is not None: if default and default_device is not None:
for param in params.values(): for param_name, param in params.items():
param.data = param.data.to(device=default_device) param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m) module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem module_offload_mem = module_mem

View File

@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
offload_stream = None offload_stream = None
xfer_dest = None xfer_dest = None

View File

@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel):
aspect_ratio: str = Field(...) aspect_ratio: str = Field(...)
n: int = Field(...) n: int = Field(...)
seed: int = Field(...) seed: int = Field(...)
response_for: str = Field("url") response_format: str = Field("url")
resolution: str = Field(...)
class InputUrlObject(BaseModel): class InputUrlObject(BaseModel):
@ -16,12 +17,13 @@ class InputUrlObject(BaseModel):
class ImageEditRequest(BaseModel): class ImageEditRequest(BaseModel):
model: str = Field(...) model: str = Field(...)
image: InputUrlObject = Field(...) images: list[InputUrlObject] = Field(...)
prompt: str = Field(...) prompt: str = Field(...)
resolution: str = Field(...) resolution: str = Field(...)
n: int = Field(...) n: int = Field(...)
seed: int = Field(...) seed: int = Field(...)
response_for: str = Field("url") response_format: str = Field("url")
aspect_ratio: str | None = Field(...)
class VideoGenerationRequest(BaseModel): class VideoGenerationRequest(BaseModel):
@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel):
revised_prompt: str | None = Field(None) revised_prompt: str | None = Field(None)
class UsageObject(BaseModel):
cost_in_usd_ticks: int | None = Field(None)
class ImageGenerationResponse(BaseModel): class ImageGenerationResponse(BaseModel):
data: list[ImageResponseObject] = Field(...) data: list[ImageResponseObject] = Field(...)
usage: UsageObject | None = Field(None)
class VideoGenerationResponse(BaseModel): class VideoGenerationResponse(BaseModel):
@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel):
status: str | None = Field(None) status: str | None = Field(None)
video: VideoResponseObject | None = Field(None) video: VideoResponseObject | None = Field(None)
model: str | None = Field(None) model: str | None = Field(None)
usage: UsageObject | None = Field(None)

View File

@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...) JobId: str = Field(...)
class To3DUVFileInput(BaseModel): class TaskFile3DInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX") Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...) Url: str = Field(...)
class To3DUVTaskRequest(BaseModel): class To3DUVTaskRequest(BaseModel):
File: To3DUVFileInput = Field(...) File: TaskFile3DInput = Field(...)
class To3DPartTaskRequest(BaseModel):
File: TaskFile3DInput = Field(...)
class TextureEditImageInfo(BaseModel): class TextureEditImageInfo(BaseModel):
@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel):
class TextureEditTaskRequest(BaseModel): class TextureEditTaskRequest(BaseModel):
File3D: To3DUVFileInput = Field(...) File3D: TaskFile3DInput = Field(...)
Image: TextureEditImageInfo | None = Field(None) Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None) Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None) EnablePBR: bool | None = Field(None)
class SmartTopologyRequest(BaseModel):
File3D: TaskFile3DInput = Field(...)
PolygonType: str | None = Field(...)
FaceLevel: str | None = Field(...)

View File

@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel):
keep_original_sound: str = Field(...) keep_original_sound: str = Field(...)
character_orientation: str = Field(...) character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'") mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)

View File

@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
) )
class GeminiModel(str, Enum):
"""
Gemini Model Names allowed by comfy-api
"""
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
gemini_2_5_pro = "gemini-2.5-pro"
gemini_2_5_flash = "gemini-2.5-flash"
gemini_3_0_pro = "gemini-3-pro-preview"
class GeminiImageModel(str, Enum): class GeminiImageModel(str, Enum):
""" """
Gemini Image Model Names allowed by comfy-api Gemini Image Model Names allowed by comfy-api
@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
input_tokens_price = 0.30 input_tokens_price = 0.30
output_text_tokens_price = 2.50 output_text_tokens_price = 2.50
output_image_tokens_price = 30.0 output_image_tokens_price = 30.0
elif response.modelVersion == "gemini-3-pro-preview": elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"):
input_tokens_price = 2 input_tokens_price = 2
output_text_tokens_price = 12.0 output_text_tokens_price = 12.0
output_image_tokens_price = 0.0 output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
input_tokens_price = 0.25
output_text_tokens_price = 1.50
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3-pro-image-preview": elif response.modelVersion == "gemini-3-pro-image-preview":
input_tokens_price = 2 input_tokens_price = 2
output_text_tokens_price = 12.0 output_text_tokens_price = 12.0
@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=GeminiModel, options=[
default=GeminiModel.gemini_2_5_pro, "gemini-2.5-pro-preview-05-06",
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-3-pro-preview",
"gemini-3-1-pro",
"gemini-3-1-flash-lite",
],
default="gemini-3-1-pro",
tooltip="The Gemini model to use for generating responses.", tooltip="The Gemini model to use for generating responses.",
), ),
IO.Int.Input( IO.Int.Input(
@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode):
"usd": [0.00125, 0.01], "usd": [0.00125, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
} }
: $contains($m, "gemini-3-pro-preview") ? { : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? {
"type": "list_usd", "type": "list_usd",
"usd": [0.002, 0.012], "usd": [0.002, 0.012],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
} }
: $contains($m, "gemini-3-1-flash-lite") ? {
"type": "list_usd",
"usd": [0.00025, 0.0015],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"} : {"type":"text", "text":"Token-based"}
) )
""", """,
@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode):
files: list[GeminiPart] | None = None, files: list[GeminiPart] | None = None,
system_prompt: str = "", system_prompt: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) if model == "gemini-3-pro-preview":
model = "gemini-3.1-pro-preview" # model "gemini-3-pro-preview" will be soon deprecated by Google
elif model == "gemini-3-1-pro":
model = "gemini-3.1-pro-preview"
elif model == "gemini-3-1-flash-lite":
model = "gemini-3.1-flash-lite-preview"
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [GeminiPart(text=prompt)] parts: list[GeminiPart] = [GeminiPart(text=prompt)]
# Add other modal parts
if images is not None: if images is not None:
parts.extend(await create_image_parts(cls, images)) parts.extend(await create_image_parts(cls, images))
if audio is not None: if audio is not None:

View File

@ -27,6 +27,12 @@ from comfy_api_nodes.util import (
) )
def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000
return None
class GrokImageNode(IO.ComfyNode): class GrokImageNode(IO.ComfyNode):
@classmethod @classmethod
@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode):
category="api node/image/Grok", category="api node/image/Grok",
description="Generate images using Grok based on a text prompt", description="Generate images using Grok based on a text prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]), IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; " tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": $rate * widgets.number_of_images}
)
""",
), ),
) )
@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
number_of_images: int, number_of_images: int,
seed: int, seed: int,
resolution: str = "1K",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
response = await sync_op( response = await sync_op(
@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
n=number_of_images, n=number_of_images,
seed=seed, seed=seed,
resolution=resolution.lower(),
), ),
response_model=ImageGenerationResponse, response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
) )
if len(response.data) == 1: if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode):
category="api node/image/Grok", category="api node/image/Grok",
description="Modify an existing image based on a text prompt", description="Modify an existing image based on a text prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]), IO.Combo.Input(
IO.Image.Input("image"), "model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.Image.Input("image", display_name="images"),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
tooltip="The text prompt used to generate the image", tooltip="The text prompt used to generate the image",
), ),
IO.Combo.Input("resolution", options=["1K"]), IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Int.Input( IO.Int.Input(
"number_of_images", "number_of_images",
default=1, default=1,
@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; " tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"9:16",
"16:9",
"9:19.5",
"19.5:9",
"9:20",
"20:9",
"1:2",
"2:1",
],
optional=True,
tooltip="Only allowed when multiple images are connected to the image input.",
),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
)
""",
), ),
) )
@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode):
resolution: str, resolution: str,
number_of_images: int, number_of_images: int,
seed: int, seed: int,
aspect_ratio: str = "auto",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1: if model == "grok-imagine-image-pro":
raise ValueError("Only one input image is supported.") if get_number_of_images(image) > 1:
raise ValueError("The pro model supports only 1 input image.")
elif get_number_of_images(image) > 3:
raise ValueError("A maximum of 3 input images is supported.")
if aspect_ratio != "auto" and get_number_of_images(image) == 1:
raise ValueError(
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
)
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
data=ImageEditRequest( data=ImageEditRequest(
model=model, model=model,
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
prompt=prompt, prompt=prompt,
resolution=resolution.lower(), resolution=resolution.lower(),
n=number_of_images, n=number_of_images,
seed=seed, seed=seed,
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
), ),
response_model=ImageGenerationResponse, response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
) )
if len(response.data) == 1: if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode):
category="api node/video/Grok", category="api node/video/Grok",
description="Generate video from a prompt or an image", description="Generate video from a prompt or an image",
inputs=[ inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]), IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
expr=""" expr="""
( (
$base := 0.181 * widgets.duration; $rate := widgets.resolution = "720p" ? 0.07 : 0.05;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
) )
""", """,
@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete", status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse, response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
) )
return IO.NodeOutput(await download_url_to_video_output(response.video.url)) return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode):
category="api node/video/Grok", category="api node/video/Grok",
description="Edit an existing video based on a text prompt.", description="Edit an existing video based on a text prompt.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]), IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""",
), ),
) )
@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete", status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse, response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
) )
return IO.NodeOutput(await download_url_to_video_output(response.video.url)) return IO.NodeOutput(await download_url_to_video_output(response.video.url))

View File

@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage, Hunyuan3DViewImage,
InputGenerateType, InputGenerateType,
ResultFile3D, ResultFile3D,
SmartTopologyRequest,
TaskFile3DInput,
TextureEditTaskRequest, TextureEditTaskRequest,
To3DPartTaskRequest,
To3DProTaskCreateResponse, To3DProTaskCreateResponse,
To3DProTaskQueryRequest, To3DProTaskQueryRequest,
To3DProTaskRequest, To3DProTaskRequest,
To3DProTaskResultResponse, To3DProTaskResultResponse,
To3DUVFileInput,
To3DUVTaskRequest, To3DUVTaskRequest,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
download_url_to_file_3d, download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side, downscale_image_tensor_by_max_side,
poll_op, poll_op,
sync_op, sync_op,
@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
outputs=[ outputs=[
IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"), IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
response_model=To3DProTaskCreateResponse, response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest( data=To3DUVTaskRequest(
File=To3DUVFileInput( File=TaskFile3DInput(
Type=file_format.upper(), Type=file_format.upper(),
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
) )
@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
return IO.NodeOutput( return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
) )
@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
response_model=To3DProTaskCreateResponse, response_model=To3DProTaskCreateResponse,
data=TextureEditTaskRequest( data=TextureEditTaskRequest(
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url), File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
Prompt=prompt, Prompt=prompt,
EnablePBR=True, EnablePBR=True,
), ),
@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode):
cls, cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
response_model=To3DProTaskCreateResponse, response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest( data=To3DPartTaskRequest(
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url), File=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
), ),
is_rate_limited=_is_tencent_rate_limited, is_rate_limited=_is_tencent_rate_limited,
) )
@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode):
) )
class TencentSmartTopologyNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TencentSmartTopologyNode",
display_name="Hunyuan3D: Smart Topology",
category="api node/3d/Tencent",
description="Perform smart retopology on a 3D model. "
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny],
tooltip="Input 3D model (GLB or OBJ)",
),
IO.Combo.Input(
"polygon_type",
options=["triangle", "quadrilateral"],
tooltip="Surface composition type.",
),
IO.Combo.Input(
"face_level",
options=["medium", "high", "low"],
tooltip="Polygon reduction level.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'),
)
SUPPORTED_FORMATS = {"glb", "obj"}
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
polygon_type: str,
face_level: str,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format not in cls.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
)
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"),
response_model=To3DProTaskCreateResponse,
data=SmartTopologyRequest(
File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
PolygonType=polygon_type,
FaceLevel=face_level,
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
)
class TencentHunyuan3DExtension(ComfyExtension): class TencentHunyuan3DExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
TencentTextToModelNode, TencentTextToModelNode,
TencentImageToModelNode, TencentImageToModelNode,
# TencentModelTo3DUVNode, TencentModelTo3DUVNode,
# Tencent3DTextureEditNode, # Tencent3DTextureEditNode,
Tencent3DPartNode, Tencent3DPartNode,
TencentSmartTopologyNode,
] ]

View File

@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode):
"but the character orientation matches the reference image (camera/other details via prompt).", "but the character orientation matches the reference image (camera/other details via prompt).",
), ),
IO.Combo.Input("mode", options=["pro", "std"]), IO.Combo.Input("mode", options=["pro", "std"]),
IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound: bool, keep_original_sound: bool,
character_orientation: str, character_orientation: str,
mode: str, mode: str,
model: str = "kling-v2-6",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, max_length=2500) validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340) validate_image_dimensions(reference_image, min_width=340, min_height=340)
@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound="yes" if keep_original_sound else "no", keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation, character_orientation=character_orientation,
mode=mode, mode=mode,
model_name=model,
), ),
) )
if response.code: if response.code:

View File

@ -83,7 +83,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
async def sync_op( async def sync_op(

View File

@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode):
return frame_idx, latent_idx return frame_idx, latent_idx
@classmethod @classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None):
keyframe_idxs, _ = get_keyframe_idxs(cond) keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 if causal_fix is None:
causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix)
pixel_coords[:, 0] += frame_idx pixel_coords[:, 0] += frame_idx
# The following adjusts keyframe end positions for small grid IC-LoRA. # The following adjusts keyframe end positions for small grid IC-LoRA.
@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode):
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
@classmethod @classmethod
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None):
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
raise ValueError("Adding guide to a combined AV latent is not supported.") raise ValueError("Adding guide to a combined AV latent is not supported.")
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
if guide_mask is not None: if guide_mask is not None:
target_h = max(noise_mask.shape[3], guide_mask.shape[3]) target_h = max(noise_mask.shape[3], guide_mask.shape[3])

119
comfy_extras/nodes_math.py Normal file
View File

@ -0,0 +1,119 @@
"""Math expression node using simpleeval for safe evaluation.
Provides a ComfyMathExpression node that evaluates math expressions
against dynamically-grown numeric inputs.
"""
from __future__ import annotations
import math
import string
from simpleeval import simple_eval
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
MAX_EXPONENT = 4000
def _variadic_sum(*args):
"""Support both sum(values) and sum(a, b, c)."""
if len(args) == 1 and hasattr(args[0], "__iter__"):
return sum(args[0])
return sum(args)
def _safe_pow(base, exp):
"""Wrap pow() with an exponent cap to prevent DoS via huge exponents.
The ** operator is already guarded by simpleeval's safe_power, but
pow() as a callable bypasses that guard.
"""
if abs(exp) > MAX_EXPONENT:
raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})")
return pow(base, exp)
# Whitelist of callables exposed to user expressions via simpleeval.
# Only side-effect-free numeric helpers are included; pow is wrapped
# (_safe_pow) to cap exponent magnitude, and sum is variadic-friendly.
MATH_FUNCTIONS = {
    "sum": _variadic_sum,
    "min": min,
    "max": max,
    "abs": abs,
    "round": round,
    "pow": _safe_pow,
    "sqrt": math.sqrt,
    "ceil": math.ceil,
    "floor": math.floor,
    "log": math.log,
    "log2": math.log2,
    "log10": math.log10,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "int": int,
    "float": float,
}
class MathExpressionNode(io.ComfyNode):
    """Evaluates a math expression against dynamically-grown inputs."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        # Inputs auto-grow one letter at a time (a, b, c, ...) as the user
        # connects more values; at least one input is always present.
        template = io.Autogrow.TemplateNames(
            input=io.MultiType.Input("value", [io.Float, io.Int]),
            names=list(string.ascii_lowercase),
            min=1,
        )
        return io.Schema(
            node_id="ComfyMathExpression",
            display_name="Math Expression",
            category="math",
            search_aliases=[
                "expression", "formula", "calculate", "calculator",
                "eval", "math",
            ],
            inputs=[
                io.String.Input("expression", default="a + b", multiline=True),
                io.Autogrow.Input("values", template=template),
            ],
            outputs=[
                io.Float.Output(display_name="FLOAT"),
                io.Int.Output(display_name="INT"),
            ],
        )

    @classmethod
    def execute(
        cls, expression: str, values: io.Autogrow.Type
    ) -> io.NodeOutput:
        """Evaluate *expression* with the connected inputs bound by name.

        Returns the result both as a float and as a truncated int.

        Raises:
            ValueError: on an empty expression, a non-numeric result, or a
                non-finite (inf/nan) result.
        """
        if not expression.strip():
            raise ValueError("Expression cannot be empty.")
        names: dict = dict(values)
        # Also expose the whole input list as "values" for sum(values) etc.
        names["values"] = list(values.values())
        result = simple_eval(expression, names=names, functions=MATH_FUNCTIONS)
        # bool check must come first because bool is a subclass of int in Python
        is_numeric = isinstance(result, (int, float)) and not isinstance(result, bool)
        if not is_numeric:
            raise ValueError(
                f"Math Expression '{expression}' must evaluate to a numeric result, "
                f"got {type(result).__name__}: {result!r}"
            )
        if not math.isfinite(result):
            raise ValueError(
                f"Math Expression '{expression}' produced a non-finite result: {result}"
            )
        return io.NodeOutput(float(result), int(result))
class MathExtension(ComfyExtension):
    """Extension that registers the math nodes provided by this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [MathExpressionNode]
async def comfy_entrypoint() -> MathExtension:
    """Entry point called by ComfyUI to load this extension."""
    return MathExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.16.0" __version__ = "0.16.3"

View File

@ -2449,6 +2449,7 @@ async def init_builtin_extra_nodes():
"nodes_replacements.py", "nodes_replacements.py",
"nodes_nag.py", "nodes_nag.py",
"nodes_sdpose.py", "nodes_sdpose.py",
"nodes_math.py",
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.16.0" version = "0.16.3"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19 comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.7 comfyui-workflow-templates==0.9.10
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde
@ -22,8 +22,9 @@ alembic
SQLAlchemy SQLAlchemy
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.6 comfy-aimdo>=0.2.7
requests requests
simpleeval>=1.0.0
#non essential dependencies: #non essential dependencies:
kornia>=0.7.1 kornia>=0.7.1

View File

@ -0,0 +1,197 @@
import math
import pytest
from collections import OrderedDict
from unittest.mock import patch, MagicMock
# Stub out ComfyUI runtime modules so comfy_extras.nodes_math can be
# imported without a running server environment.
mock_nodes = MagicMock()
# NOTE(review): MAX_RESOLUTION is set on the stub, presumably because some
# module read at import time expects it — confirm against comfy_extras.
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
    from comfy_extras.nodes_math import MathExpressionNode
class TestMathExpressionExecute:
    """Unit tests for MathExpressionNode.execute.

    Each test invokes the node's classmethod directly with an OrderedDict
    of named inputs, mirroring how Autogrow values are delivered at
    runtime. execute returns a (float, int) pair, indexed as
    result[0] / result[1].
    """

    @staticmethod
    def _exec(expression: str, **kwargs) -> object:
        # Helper: run the node with keyword args as the named inputs (a, b, ...).
        values = OrderedDict(kwargs)
        return MathExpressionNode.execute(expression, values)

    def test_addition(self):
        result = self._exec("a + b", a=3, b=4)
        assert result[0] == 7.0
        assert result[1] == 7

    def test_subtraction(self):
        result = self._exec("a - b", a=10, b=3)
        assert result[0] == 7.0
        assert result[1] == 7

    def test_multiplication(self):
        result = self._exec("a * b", a=3, b=5)
        assert result[0] == 15.0
        assert result[1] == 15

    def test_division(self):
        result = self._exec("a / b", a=10, b=4)
        assert result[0] == 2.5
        assert result[1] == 2

    def test_single_input(self):
        result = self._exec("a * 2", a=5)
        assert result[0] == 10.0
        assert result[1] == 10

    def test_three_inputs(self):
        result = self._exec("a + b + c", a=1, b=2, c=3)
        assert result[0] == 6.0
        assert result[1] == 6

    def test_float_inputs(self):
        result = self._exec("a + b", a=1.5, b=2.5)
        assert result[0] == 4.0
        assert result[1] == 4

    def test_mixed_int_float_inputs(self):
        result = self._exec("a * b", a=1024, b=1.5)
        assert result[0] == 1536.0
        assert result[1] == 1536

    def test_mixed_resolution_scale(self):
        result = self._exec("a * b", a=512, b=0.75)
        assert result[0] == 384.0
        assert result[1] == 384

    def test_sum_values_array(self):
        result = self._exec("sum(values)", a=1, b=2, c=3)
        assert result[0] == 6.0

    def test_sum_variadic(self):
        result = self._exec("sum(a, b, c)", a=1, b=2, c=3)
        assert result[0] == 6.0

    def test_min_values(self):
        result = self._exec("min(values)", a=5, b=2, c=8)
        assert result[0] == 2.0

    def test_max_values(self):
        result = self._exec("max(values)", a=5, b=2, c=8)
        assert result[0] == 8.0

    def test_abs_function(self):
        result = self._exec("abs(a)", a=-7)
        assert result[0] == 7.0
        assert result[1] == 7

    def test_sqrt(self):
        result = self._exec("sqrt(a)", a=16)
        assert result[0] == 4.0
        assert result[1] == 4

    def test_ceil(self):
        result = self._exec("ceil(a)", a=2.3)
        assert result[0] == 3.0
        assert result[1] == 3

    def test_floor(self):
        result = self._exec("floor(a)", a=2.7)
        assert result[0] == 2.0
        assert result[1] == 2

    def test_sin(self):
        result = self._exec("sin(a)", a=0)
        assert result[0] == 0.0

    def test_log10(self):
        result = self._exec("log10(a)", a=100)
        assert result[0] == 2.0
        assert result[1] == 2

    def test_float_output_type(self):
        result = self._exec("a + b", a=1, b=2)
        assert isinstance(result[0], float)

    def test_int_output_type(self):
        result = self._exec("a + b", a=1, b=2)
        assert isinstance(result[1], int)

    def test_non_numeric_result_raises(self):
        with pytest.raises(ValueError, match="must evaluate to a numeric result"):
            self._exec("'hello'", a=42)

    def test_undefined_function_raises(self):
        with pytest.raises(Exception, match="not defined"):
            self._exec("str(a)", a=42)

    def test_boolean_result_raises(self):
        # Comparisons yield bool, which the node rejects even though bool
        # is a subclass of int.
        with pytest.raises(ValueError, match="got bool"):
            self._exec("a > b", a=5, b=3)

    def test_empty_expression_raises(self):
        with pytest.raises(ValueError, match="Expression cannot be empty"):
            self._exec("", a=1)

    def test_whitespace_only_expression_raises(self):
        with pytest.raises(ValueError, match="Expression cannot be empty"):
            self._exec(" ", a=1)

    # --- Missing function coverage (round, pow, log, log2, cos, tan) ---
    def test_round(self):
        result = self._exec("round(a)", a=2.7)
        assert result[0] == 3.0
        assert result[1] == 3

    def test_round_with_ndigits(self):
        result = self._exec("round(a, 2)", a=3.14159)
        assert result[0] == pytest.approx(3.14)

    def test_pow(self):
        result = self._exec("pow(a, b)", a=2, b=10)
        assert result[0] == 1024.0
        assert result[1] == 1024

    def test_log(self):
        result = self._exec("log(a)", a=math.e)
        assert result[0] == pytest.approx(1.0)

    def test_log2(self):
        result = self._exec("log2(a)", a=8)
        assert result[0] == pytest.approx(3.0)

    def test_cos(self):
        result = self._exec("cos(a)", a=0)
        assert result[0] == 1.0

    def test_tan(self):
        result = self._exec("tan(a)", a=0)
        assert result[0] == 0.0

    # --- int/float converter functions ---
    def test_int_converter(self):
        result = self._exec("int(a / b)", a=7, b=2)
        assert result[1] == 3

    def test_float_converter(self):
        result = self._exec("float(a)", a=5)
        assert result[0] == 5.0

    # --- Error path tests ---
    def test_division_by_zero_raises(self):
        with pytest.raises(ZeroDivisionError):
            self._exec("a / b", a=1, b=0)

    def test_sqrt_negative_raises(self):
        with pytest.raises(ValueError, match="math domain error"):
            self._exec("sqrt(a)", a=-1)

    def test_overflow_inf_raises(self):
        # 1e308 * 10 overflows float to inf, which execute rejects.
        with pytest.raises(ValueError, match="non-finite result"):
            self._exec("a * b", a=1e308, b=10)

    def test_pow_huge_exponent_raises(self):
        with pytest.raises(ValueError, match="Exponent .* exceeds maximum"):
            self._exec("pow(a, b)", a=10, b=10000000)