mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-19 06:52:31 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
0aeeb8609d
@ -27,6 +27,7 @@ class AudioEncoderModel():
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
self.model_sample_rate = 16000
|
self.model_sample_rate = 16000
|
||||||
|
comfy.model_management.archive_model_dtypes(self.model)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||||
|
|||||||
@ -939,7 +939,7 @@ def text_encoder_offload_device():
|
|||||||
def text_encoder_device():
|
def text_encoder_device():
|
||||||
if args.gpu_only:
|
if args.gpu_only:
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
|
||||||
if should_use_fp16(prioritize_performance=False):
|
if should_use_fp16(prioritize_performance=False):
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -715,8 +715,8 @@ class ModelPatcher:
|
|||||||
default = True # default random weights in non leaf modules
|
default = True # default random weights in non leaf modules
|
||||||
break
|
break
|
||||||
if default and default_device is not None:
|
if default and default_device is not None:
|
||||||
for param in params.values():
|
for param_name, param in params.items():
|
||||||
param.data = param.data.to(device=default_device)
|
param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
|
||||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
module_offload_mem = module_mem
|
module_offload_mem = module_mem
|
||||||
|
|||||||
@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel):
|
|||||||
JobId: str = Field(...)
|
JobId: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class To3DUVFileInput(BaseModel):
|
class TaskFile3DInput(BaseModel):
|
||||||
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
|
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
|
||||||
Url: str = Field(...)
|
Url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class To3DUVTaskRequest(BaseModel):
|
class To3DUVTaskRequest(BaseModel):
|
||||||
File: To3DUVFileInput = Field(...)
|
File: TaskFile3DInput = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class To3DPartTaskRequest(BaseModel):
|
||||||
|
File: TaskFile3DInput = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class TextureEditImageInfo(BaseModel):
|
class TextureEditImageInfo(BaseModel):
|
||||||
@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TextureEditTaskRequest(BaseModel):
|
class TextureEditTaskRequest(BaseModel):
|
||||||
File3D: To3DUVFileInput = Field(...)
|
File3D: TaskFile3DInput = Field(...)
|
||||||
Image: TextureEditImageInfo | None = Field(None)
|
Image: TextureEditImageInfo | None = Field(None)
|
||||||
Prompt: str | None = Field(None)
|
Prompt: str | None = Field(None)
|
||||||
EnablePBR: bool | None = Field(None)
|
EnablePBR: bool | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class SmartTopologyRequest(BaseModel):
|
||||||
|
File3D: TaskFile3DInput = Field(...)
|
||||||
|
PolygonType: str | None = Field(...)
|
||||||
|
FaceLevel: str | None = Field(...)
|
||||||
|
|||||||
@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GeminiModel(str, Enum):
|
|
||||||
"""
|
|
||||||
Gemini Model Names allowed by comfy-api
|
|
||||||
"""
|
|
||||||
|
|
||||||
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
|
|
||||||
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
|
||||||
gemini_2_5_pro = "gemini-2.5-pro"
|
|
||||||
gemini_2_5_flash = "gemini-2.5-flash"
|
|
||||||
gemini_3_0_pro = "gemini-3-pro-preview"
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageModel(str, Enum):
|
class GeminiImageModel(str, Enum):
|
||||||
"""
|
"""
|
||||||
Gemini Image Model Names allowed by comfy-api
|
Gemini Image Model Names allowed by comfy-api
|
||||||
@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
|||||||
input_tokens_price = 0.30
|
input_tokens_price = 0.30
|
||||||
output_text_tokens_price = 2.50
|
output_text_tokens_price = 2.50
|
||||||
output_image_tokens_price = 30.0
|
output_image_tokens_price = 30.0
|
||||||
elif response.modelVersion == "gemini-3-pro-preview":
|
elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"):
|
||||||
input_tokens_price = 2
|
input_tokens_price = 2
|
||||||
output_text_tokens_price = 12.0
|
output_text_tokens_price = 12.0
|
||||||
output_image_tokens_price = 0.0
|
output_image_tokens_price = 0.0
|
||||||
|
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
|
||||||
|
input_tokens_price = 0.25
|
||||||
|
output_text_tokens_price = 1.50
|
||||||
|
output_image_tokens_price = 0.0
|
||||||
elif response.modelVersion == "gemini-3-pro-image-preview":
|
elif response.modelVersion == "gemini-3-pro-image-preview":
|
||||||
input_tokens_price = 2
|
input_tokens_price = 2
|
||||||
output_text_tokens_price = 12.0
|
output_text_tokens_price = 12.0
|
||||||
@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=GeminiModel,
|
options=[
|
||||||
default=GeminiModel.gemini_2_5_pro,
|
"gemini-2.5-pro-preview-05-06",
|
||||||
|
"gemini-2.5-flash-preview-04-17",
|
||||||
|
"gemini-2.5-pro",
|
||||||
|
"gemini-2.5-flash",
|
||||||
|
"gemini-3-pro-preview",
|
||||||
|
"gemini-3-1-pro",
|
||||||
|
"gemini-3-1-flash-lite",
|
||||||
|
],
|
||||||
|
default="gemini-3-1-pro",
|
||||||
tooltip="The Gemini model to use for generating responses.",
|
tooltip="The Gemini model to use for generating responses.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
"usd": [0.00125, 0.01],
|
"usd": [0.00125, 0.01],
|
||||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||||
}
|
}
|
||||||
: $contains($m, "gemini-3-pro-preview") ? {
|
: ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? {
|
||||||
"type": "list_usd",
|
"type": "list_usd",
|
||||||
"usd": [0.002, 0.012],
|
"usd": [0.002, 0.012],
|
||||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||||
}
|
}
|
||||||
|
: $contains($m, "gemini-3-1-flash-lite") ? {
|
||||||
|
"type": "list_usd",
|
||||||
|
"usd": [0.00025, 0.0015],
|
||||||
|
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||||
|
}
|
||||||
: {"type":"text", "text":"Token-based"}
|
: {"type":"text", "text":"Token-based"}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
files: list[GeminiPart] | None = None,
|
files: list[GeminiPart] | None = None,
|
||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
if model == "gemini-3-pro-preview":
|
||||||
|
model = "gemini-3.1-pro-preview" # model "gemini-3-pro-preview" will be soon deprecated by Google
|
||||||
|
elif model == "gemini-3-1-pro":
|
||||||
|
model = "gemini-3.1-pro-preview"
|
||||||
|
elif model == "gemini-3-1-flash-lite":
|
||||||
|
model = "gemini-3.1-flash-lite-preview"
|
||||||
|
|
||||||
# Create parts list with text prompt as the first part
|
|
||||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||||
|
|
||||||
# Add other modal parts
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
parts.extend(await create_image_parts(cls, images))
|
parts.extend(await create_image_parts(cls, images))
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
|
|||||||
@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
|||||||
Hunyuan3DViewImage,
|
Hunyuan3DViewImage,
|
||||||
InputGenerateType,
|
InputGenerateType,
|
||||||
ResultFile3D,
|
ResultFile3D,
|
||||||
|
SmartTopologyRequest,
|
||||||
|
TaskFile3DInput,
|
||||||
TextureEditTaskRequest,
|
TextureEditTaskRequest,
|
||||||
|
To3DPartTaskRequest,
|
||||||
To3DProTaskCreateResponse,
|
To3DProTaskCreateResponse,
|
||||||
To3DProTaskQueryRequest,
|
To3DProTaskQueryRequest,
|
||||||
To3DProTaskRequest,
|
To3DProTaskRequest,
|
||||||
To3DProTaskResultResponse,
|
To3DProTaskResultResponse,
|
||||||
To3DUVFileInput,
|
|
||||||
To3DUVTaskRequest,
|
To3DUVTaskRequest,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
download_url_to_file_3d,
|
download_url_to_file_3d,
|
||||||
download_url_to_image_tensor,
|
|
||||||
downscale_image_tensor_by_max_side,
|
downscale_image_tensor_by_max_side,
|
||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
|||||||
outputs=[
|
outputs=[
|
||||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||||
IO.File3DFBX.Output(display_name="FBX"),
|
IO.File3DFBX.Output(display_name="FBX"),
|
||||||
IO.Image.Output(),
|
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
|||||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
|
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
|
||||||
response_model=To3DProTaskCreateResponse,
|
response_model=To3DProTaskCreateResponse,
|
||||||
data=To3DUVTaskRequest(
|
data=To3DUVTaskRequest(
|
||||||
File=To3DUVFileInput(
|
File=TaskFile3DInput(
|
||||||
Type=file_format.upper(),
|
Type=file_format.upper(),
|
||||||
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
|
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
|
||||||
)
|
)
|
||||||
@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||||
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
|||||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
|
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
|
||||||
response_model=To3DProTaskCreateResponse,
|
response_model=To3DProTaskCreateResponse,
|
||||||
data=TextureEditTaskRequest(
|
data=TextureEditTaskRequest(
|
||||||
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
|
||||||
Prompt=prompt,
|
Prompt=prompt,
|
||||||
EnablePBR=True,
|
EnablePBR=True,
|
||||||
),
|
),
|
||||||
@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode):
|
|||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
|
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
|
||||||
response_model=To3DProTaskCreateResponse,
|
response_model=To3DProTaskCreateResponse,
|
||||||
data=To3DUVTaskRequest(
|
data=To3DPartTaskRequest(
|
||||||
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
File=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
|
||||||
),
|
),
|
||||||
is_rate_limited=_is_tencent_rate_limited,
|
is_rate_limited=_is_tencent_rate_limited,
|
||||||
)
|
)
|
||||||
@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentSmartTopologyNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="TencentSmartTopologyNode",
|
||||||
|
display_name="Hunyuan3D: Smart Topology",
|
||||||
|
category="api node/3d/Tencent",
|
||||||
|
description="Perform smart retopology on a 3D model. "
|
||||||
|
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
|
||||||
|
inputs=[
|
||||||
|
IO.MultiType.Input(
|
||||||
|
"model_3d",
|
||||||
|
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny],
|
||||||
|
tooltip="Input 3D model (GLB or OBJ)",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"polygon_type",
|
||||||
|
options=["triangle", "quadrilateral"],
|
||||||
|
tooltip="Surface composition type.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"face_level",
|
||||||
|
options=["medium", "high", "low"],
|
||||||
|
tooltip="Polygon reduction level.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'),
|
||||||
|
)
|
||||||
|
|
||||||
|
SUPPORTED_FORMATS = {"glb", "obj"}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model_3d: Types.File3D,
|
||||||
|
polygon_type: str,
|
||||||
|
face_level: str,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
_ = seed
|
||||||
|
file_format = model_3d.format.lower()
|
||||||
|
if file_format not in cls.SUPPORTED_FORMATS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
|
||||||
|
)
|
||||||
|
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"),
|
||||||
|
response_model=To3DProTaskCreateResponse,
|
||||||
|
data=SmartTopologyRequest(
|
||||||
|
File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
|
||||||
|
PolygonType=polygon_type,
|
||||||
|
FaceLevel=face_level,
|
||||||
|
),
|
||||||
|
is_rate_limited=_is_tencent_rate_limited,
|
||||||
|
)
|
||||||
|
if response.Error:
|
||||||
|
raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}")
|
||||||
|
result = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"),
|
||||||
|
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||||
|
response_model=To3DProTaskResultResponse,
|
||||||
|
status_extractor=lambda r: r.Status,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(
|
||||||
|
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TencentHunyuan3DExtension(ComfyExtension):
|
class TencentHunyuan3DExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
TencentTextToModelNode,
|
TencentTextToModelNode,
|
||||||
TencentImageToModelNode,
|
TencentImageToModelNode,
|
||||||
# TencentModelTo3DUVNode,
|
TencentModelTo3DUVNode,
|
||||||
# Tencent3DTextureEditNode,
|
# Tencent3DTextureEditNode,
|
||||||
Tencent3DPartNode,
|
Tencent3DPartNode,
|
||||||
|
TencentSmartTopologyNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -83,7 +83,7 @@ class _PollUIState:
|
|||||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||||
|
|
||||||
|
|
||||||
async def sync_op(
|
async def sync_op(
|
||||||
|
|||||||
119
comfy_extras/nodes_math.py
Normal file
119
comfy_extras/nodes_math.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
"""Math expression node using simpleeval for safe evaluation.
|
||||||
|
|
||||||
|
Provides a ComfyMathExpression node that evaluates math expressions
|
||||||
|
against dynamically-grown numeric inputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import string
|
||||||
|
|
||||||
|
from simpleeval import simple_eval
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
MAX_EXPONENT = 4000
|
||||||
|
|
||||||
|
|
||||||
|
def _variadic_sum(*args):
|
||||||
|
"""Support both sum(values) and sum(a, b, c)."""
|
||||||
|
if len(args) == 1 and hasattr(args[0], "__iter__"):
|
||||||
|
return sum(args[0])
|
||||||
|
return sum(args)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_pow(base, exp):
|
||||||
|
"""Wrap pow() with an exponent cap to prevent DoS via huge exponents.
|
||||||
|
|
||||||
|
The ** operator is already guarded by simpleeval's safe_power, but
|
||||||
|
pow() as a callable bypasses that guard.
|
||||||
|
"""
|
||||||
|
if abs(exp) > MAX_EXPONENT:
|
||||||
|
raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})")
|
||||||
|
return pow(base, exp)
|
||||||
|
|
||||||
|
|
||||||
|
MATH_FUNCTIONS = {
|
||||||
|
"sum": _variadic_sum,
|
||||||
|
"min": min,
|
||||||
|
"max": max,
|
||||||
|
"abs": abs,
|
||||||
|
"round": round,
|
||||||
|
"pow": _safe_pow,
|
||||||
|
"sqrt": math.sqrt,
|
||||||
|
"ceil": math.ceil,
|
||||||
|
"floor": math.floor,
|
||||||
|
"log": math.log,
|
||||||
|
"log2": math.log2,
|
||||||
|
"log10": math.log10,
|
||||||
|
"sin": math.sin,
|
||||||
|
"cos": math.cos,
|
||||||
|
"tan": math.tan,
|
||||||
|
"int": int,
|
||||||
|
"float": float,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MathExpressionNode(io.ComfyNode):
|
||||||
|
"""Evaluates a math expression against dynamically-grown inputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
autogrow = io.Autogrow.TemplateNames(
|
||||||
|
input=io.MultiType.Input("value", [io.Float, io.Int]),
|
||||||
|
names=list(string.ascii_lowercase),
|
||||||
|
min=1,
|
||||||
|
)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComfyMathExpression",
|
||||||
|
display_name="Math Expression",
|
||||||
|
category="math",
|
||||||
|
search_aliases=[
|
||||||
|
"expression", "formula", "calculate", "calculator",
|
||||||
|
"eval", "math",
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("expression", default="a + b", multiline=True),
|
||||||
|
io.Autogrow.Input("values", template=autogrow),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Float.Output(display_name="FLOAT"),
|
||||||
|
io.Int.Output(display_name="INT"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(
|
||||||
|
cls, expression: str, values: io.Autogrow.Type
|
||||||
|
) -> io.NodeOutput:
|
||||||
|
if not expression.strip():
|
||||||
|
raise ValueError("Expression cannot be empty.")
|
||||||
|
|
||||||
|
context: dict = dict(values)
|
||||||
|
context["values"] = list(values.values())
|
||||||
|
|
||||||
|
result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
|
||||||
|
# bool check must come first because bool is a subclass of int in Python
|
||||||
|
if isinstance(result, bool) or not isinstance(result, (int, float)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Math Expression '{expression}' must evaluate to a numeric result, "
|
||||||
|
f"got {type(result).__name__}: {result!r}"
|
||||||
|
)
|
||||||
|
if not math.isfinite(result):
|
||||||
|
raise ValueError(
|
||||||
|
f"Math Expression '{expression}' produced a non-finite result: {result}"
|
||||||
|
)
|
||||||
|
return io.NodeOutput(float(result), int(result))
|
||||||
|
|
||||||
|
|
||||||
|
class MathExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [MathExpressionNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> MathExtension:
|
||||||
|
return MathExtension()
|
||||||
1
nodes.py
1
nodes.py
@ -2449,6 +2449,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_replacements.py",
|
"nodes_replacements.py",
|
||||||
"nodes_nag.py",
|
"nodes_nag.py",
|
||||||
"nodes_sdpose.py",
|
"nodes_sdpose.py",
|
||||||
|
"nodes_math.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -24,6 +24,7 @@ av>=14.2.0
|
|||||||
comfy-kitchen>=0.2.7
|
comfy-kitchen>=0.2.7
|
||||||
comfy-aimdo>=0.2.7
|
comfy-aimdo>=0.2.7
|
||||||
requests
|
requests
|
||||||
|
simpleeval>=1.0
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
kornia>=0.7.1
|
kornia>=0.7.1
|
||||||
|
|||||||
197
tests-unit/comfy_extras_test/nodes_math_test.py
Normal file
197
tests-unit/comfy_extras_test/nodes_math_test.py
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from collections import OrderedDict
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
mock_nodes = MagicMock()
|
||||||
|
mock_nodes.MAX_RESOLUTION = 16384
|
||||||
|
mock_server = MagicMock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||||
|
from comfy_extras.nodes_math import MathExpressionNode
|
||||||
|
|
||||||
|
|
||||||
|
class TestMathExpressionExecute:
|
||||||
|
@staticmethod
|
||||||
|
def _exec(expression: str, **kwargs) -> object:
|
||||||
|
values = OrderedDict(kwargs)
|
||||||
|
return MathExpressionNode.execute(expression, values)
|
||||||
|
|
||||||
|
def test_addition(self):
|
||||||
|
result = self._exec("a + b", a=3, b=4)
|
||||||
|
assert result[0] == 7.0
|
||||||
|
assert result[1] == 7
|
||||||
|
|
||||||
|
def test_subtraction(self):
|
||||||
|
result = self._exec("a - b", a=10, b=3)
|
||||||
|
assert result[0] == 7.0
|
||||||
|
assert result[1] == 7
|
||||||
|
|
||||||
|
def test_multiplication(self):
|
||||||
|
result = self._exec("a * b", a=3, b=5)
|
||||||
|
assert result[0] == 15.0
|
||||||
|
assert result[1] == 15
|
||||||
|
|
||||||
|
def test_division(self):
|
||||||
|
result = self._exec("a / b", a=10, b=4)
|
||||||
|
assert result[0] == 2.5
|
||||||
|
assert result[1] == 2
|
||||||
|
|
||||||
|
def test_single_input(self):
|
||||||
|
result = self._exec("a * 2", a=5)
|
||||||
|
assert result[0] == 10.0
|
||||||
|
assert result[1] == 10
|
||||||
|
|
||||||
|
def test_three_inputs(self):
|
||||||
|
result = self._exec("a + b + c", a=1, b=2, c=3)
|
||||||
|
assert result[0] == 6.0
|
||||||
|
assert result[1] == 6
|
||||||
|
|
||||||
|
def test_float_inputs(self):
|
||||||
|
result = self._exec("a + b", a=1.5, b=2.5)
|
||||||
|
assert result[0] == 4.0
|
||||||
|
assert result[1] == 4
|
||||||
|
|
||||||
|
def test_mixed_int_float_inputs(self):
|
||||||
|
result = self._exec("a * b", a=1024, b=1.5)
|
||||||
|
assert result[0] == 1536.0
|
||||||
|
assert result[1] == 1536
|
||||||
|
|
||||||
|
def test_mixed_resolution_scale(self):
|
||||||
|
result = self._exec("a * b", a=512, b=0.75)
|
||||||
|
assert result[0] == 384.0
|
||||||
|
assert result[1] == 384
|
||||||
|
|
||||||
|
def test_sum_values_array(self):
|
||||||
|
result = self._exec("sum(values)", a=1, b=2, c=3)
|
||||||
|
assert result[0] == 6.0
|
||||||
|
|
||||||
|
def test_sum_variadic(self):
|
||||||
|
result = self._exec("sum(a, b, c)", a=1, b=2, c=3)
|
||||||
|
assert result[0] == 6.0
|
||||||
|
|
||||||
|
def test_min_values(self):
|
||||||
|
result = self._exec("min(values)", a=5, b=2, c=8)
|
||||||
|
assert result[0] == 2.0
|
||||||
|
|
||||||
|
def test_max_values(self):
|
||||||
|
result = self._exec("max(values)", a=5, b=2, c=8)
|
||||||
|
assert result[0] == 8.0
|
||||||
|
|
||||||
|
def test_abs_function(self):
|
||||||
|
result = self._exec("abs(a)", a=-7)
|
||||||
|
assert result[0] == 7.0
|
||||||
|
assert result[1] == 7
|
||||||
|
|
||||||
|
def test_sqrt(self):
|
||||||
|
result = self._exec("sqrt(a)", a=16)
|
||||||
|
assert result[0] == 4.0
|
||||||
|
assert result[1] == 4
|
||||||
|
|
||||||
|
def test_ceil(self):
|
||||||
|
result = self._exec("ceil(a)", a=2.3)
|
||||||
|
assert result[0] == 3.0
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_floor(self):
|
||||||
|
result = self._exec("floor(a)", a=2.7)
|
||||||
|
assert result[0] == 2.0
|
||||||
|
assert result[1] == 2
|
||||||
|
|
||||||
|
def test_sin(self):
|
||||||
|
result = self._exec("sin(a)", a=0)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
|
||||||
|
def test_log10(self):
|
||||||
|
result = self._exec("log10(a)", a=100)
|
||||||
|
assert result[0] == 2.0
|
||||||
|
assert result[1] == 2
|
||||||
|
|
||||||
|
def test_float_output_type(self):
|
||||||
|
result = self._exec("a + b", a=1, b=2)
|
||||||
|
assert isinstance(result[0], float)
|
||||||
|
|
||||||
|
def test_int_output_type(self):
|
||||||
|
result = self._exec("a + b", a=1, b=2)
|
||||||
|
assert isinstance(result[1], int)
|
||||||
|
|
||||||
|
def test_non_numeric_result_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="must evaluate to a numeric result"):
|
||||||
|
self._exec("'hello'", a=42)
|
||||||
|
|
||||||
|
def test_undefined_function_raises(self):
|
||||||
|
with pytest.raises(Exception, match="not defined"):
|
||||||
|
self._exec("str(a)", a=42)
|
||||||
|
|
||||||
|
def test_boolean_result_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="got bool"):
|
||||||
|
self._exec("a > b", a=5, b=3)
|
||||||
|
|
||||||
|
def test_empty_expression_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Expression cannot be empty"):
|
||||||
|
self._exec("", a=1)
|
||||||
|
|
||||||
|
def test_whitespace_only_expression_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Expression cannot be empty"):
|
||||||
|
self._exec(" ", a=1)
|
||||||
|
|
||||||
|
# --- Missing function coverage (round, pow, log, log2, cos, tan) ---
|
||||||
|
|
||||||
|
def test_round(self):
|
||||||
|
result = self._exec("round(a)", a=2.7)
|
||||||
|
assert result[0] == 3.0
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_round_with_ndigits(self):
|
||||||
|
result = self._exec("round(a, 2)", a=3.14159)
|
||||||
|
assert result[0] == pytest.approx(3.14)
|
||||||
|
|
||||||
|
def test_pow(self):
|
||||||
|
result = self._exec("pow(a, b)", a=2, b=10)
|
||||||
|
assert result[0] == 1024.0
|
||||||
|
assert result[1] == 1024
|
||||||
|
|
||||||
|
def test_log(self):
|
||||||
|
result = self._exec("log(a)", a=math.e)
|
||||||
|
assert result[0] == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_log2(self):
|
||||||
|
result = self._exec("log2(a)", a=8)
|
||||||
|
assert result[0] == pytest.approx(3.0)
|
||||||
|
|
||||||
|
def test_cos(self):
|
||||||
|
result = self._exec("cos(a)", a=0)
|
||||||
|
assert result[0] == 1.0
|
||||||
|
|
||||||
|
def test_tan(self):
|
||||||
|
result = self._exec("tan(a)", a=0)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
|
||||||
|
# --- int/float converter functions ---
|
||||||
|
|
||||||
|
def test_int_converter(self):
|
||||||
|
result = self._exec("int(a / b)", a=7, b=2)
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_float_converter(self):
|
||||||
|
result = self._exec("float(a)", a=5)
|
||||||
|
assert result[0] == 5.0
|
||||||
|
|
||||||
|
# --- Error path tests ---
|
||||||
|
|
||||||
|
def test_division_by_zero_raises(self):
|
||||||
|
with pytest.raises(ZeroDivisionError):
|
||||||
|
self._exec("a / b", a=1, b=0)
|
||||||
|
|
||||||
|
def test_sqrt_negative_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="math domain error"):
|
||||||
|
self._exec("sqrt(a)", a=-1)
|
||||||
|
|
||||||
|
def test_overflow_inf_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite result"):
|
||||||
|
self._exec("a * b", a=1e308, b=10)
|
||||||
|
|
||||||
|
def test_pow_huge_exponent_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Exponent .* exceeds maximum"):
|
||||||
|
self._exec("pow(a, b)", a=10, b=10000000)
|
||||||
Loading…
Reference in New Issue
Block a user