mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-05 19:12:41 +08:00
Compare commits
15 Commits
1dddbbf985
...
1cbfaed155
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cbfaed155 | ||
|
|
5f62440fbb | ||
|
|
ac91c340f4 | ||
|
|
2db3b0ff90 | ||
|
|
6516ab335d | ||
|
|
ad53e78f11 | ||
|
|
29011ba87e | ||
|
|
cd4985e2f3 | ||
|
|
bfe31d0b9d | ||
|
|
dcf686857c | ||
|
|
17eed38750 | ||
|
|
f4623c0e1b | ||
|
|
5e4bbca1ad | ||
|
|
e17571d9be | ||
|
|
6540aa0400 |
6
.github/workflows/release-stable-all.yml
vendored
6
.github/workflows/release-stable-all.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "9"
|
||||
python_patch: "11"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
@ -65,11 +65,11 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
name: "Release AMD ROCm 7.2"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm711"
|
||||
cache_tag: "rocm72"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
|
||||
@ -479,10 +479,12 @@ class WanVAE(nn.Module):
|
||||
|
||||
def encode(self, x):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@ -502,10 +504,11 @@ class WanVAE(nn.Module):
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
iter_ = z.shape[2]
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
|
||||
@ -13,17 +13,6 @@ class Text2ImageTaskCreationRequest(BaseModel):
|
||||
watermark: bool | None = Field(False)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: str | None = Field("adaptive")
|
||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(False)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
122
comfy_api_nodes/apis/magnific.py
Normal file
122
comfy_api_nodes/apis/magnific.py
Normal file
@ -0,0 +1,122 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class InputPortraitMode(TypedDict):
|
||||
portrait_mode: str
|
||||
portrait_style: str
|
||||
portrait_beautifier: str
|
||||
|
||||
|
||||
class InputAdvancedSettings(TypedDict):
|
||||
advanced_settings: str
|
||||
whites: int
|
||||
blacks: int
|
||||
brightness: int
|
||||
contrast: int
|
||||
saturation: int
|
||||
engine: str
|
||||
transfer_light_a: str
|
||||
transfer_light_b: str
|
||||
fixed_generation: bool
|
||||
|
||||
|
||||
class InputSkinEnhancerMode(TypedDict):
|
||||
mode: str
|
||||
skin_detail: int
|
||||
optimized_for: str
|
||||
|
||||
|
||||
class ImageUpscalerCreativeRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
scale_factor: str = Field(...)
|
||||
optimized_for: str = Field(...)
|
||||
prompt: str | None = Field(None)
|
||||
creativity: int = Field(...)
|
||||
hdr: int = Field(...)
|
||||
resemblance: int = Field(...)
|
||||
fractality: int = Field(...)
|
||||
engine: str = Field(...)
|
||||
|
||||
|
||||
class ImageUpscalerPrecisionV2Request(BaseModel):
|
||||
image: str = Field(...)
|
||||
sharpen: int = Field(...)
|
||||
smart_grain: int = Field(...)
|
||||
ultra_detail: int = Field(...)
|
||||
flavor: str = Field(...)
|
||||
scale_factor: int = Field(...)
|
||||
|
||||
|
||||
class ImageRelightAdvancedSettingsRequest(BaseModel):
|
||||
whites: int = Field(...)
|
||||
blacks: int = Field(...)
|
||||
brightness: int = Field(...)
|
||||
contrast: int = Field(...)
|
||||
saturation: int = Field(...)
|
||||
engine: str = Field(...)
|
||||
transfer_light_a: str = Field(...)
|
||||
transfer_light_b: str = Field(...)
|
||||
fixed_generation: bool = Field(...)
|
||||
|
||||
|
||||
class ImageRelightRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
prompt: str | None = Field(None)
|
||||
transfer_light_from_reference_image: str | None = Field(None)
|
||||
light_transfer_strength: int = Field(...)
|
||||
interpolate_from_original: bool = Field(...)
|
||||
change_background: bool = Field(...)
|
||||
style: str = Field(...)
|
||||
preserve_details: bool = Field(...)
|
||||
advanced_settings: ImageRelightAdvancedSettingsRequest | None = Field(...)
|
||||
|
||||
|
||||
class ImageStyleTransferRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
reference_image: str = Field(...)
|
||||
prompt: str | None = Field(None)
|
||||
style_strength: int = Field(...)
|
||||
structure_strength: int = Field(...)
|
||||
is_portrait: bool = Field(...)
|
||||
portrait_style: str | None = Field(...)
|
||||
portrait_beautifier: str | None = Field(...)
|
||||
flavor: str = Field(...)
|
||||
engine: str = Field(...)
|
||||
fixed_generation: bool = Field(...)
|
||||
|
||||
|
||||
class ImageSkinEnhancerCreativeRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
sharpen: int = Field(...)
|
||||
smart_grain: int = Field(...)
|
||||
|
||||
|
||||
class ImageSkinEnhancerFaithfulRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
sharpen: int = Field(...)
|
||||
smart_grain: int = Field(...)
|
||||
skin_detail: int = Field(...)
|
||||
|
||||
|
||||
class ImageSkinEnhancerFlexibleRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
sharpen: int = Field(...)
|
||||
smart_grain: int = Field(...)
|
||||
optimized_for: str = Field(...)
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
"""Unified response model that handles both wrapped and unwrapped API responses."""
|
||||
|
||||
task_id: str = Field(...)
|
||||
status: str = Field(validation_alias=AliasChoices("status", "task_status"))
|
||||
generated: list[str] | None = Field(None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def unwrap_data(cls, values: dict) -> dict:
|
||||
if "data" in values and isinstance(values["data"], dict):
|
||||
return values["data"]
|
||||
return values
|
||||
@ -9,7 +9,6 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedream4Options,
|
||||
@ -174,99 +173,6 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
|
||||
|
||||
class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageEditNode",
|
||||
display_name="ByteDance Image Edit",
|
||||
category="api node/image/ByteDance",
|
||||
description="Edit images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The base image to edit",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Instruction to edit image",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation",
|
||||
optional=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance_scale",
|
||||
default=5.5,
|
||||
min=1.0,
|
||||
max=10.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Higher value makes the image follow the prompt more closely",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
guidance_scale: float,
|
||||
watermark: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
validate_image_aspect_ratio(image, (1, 3), (3, 1))
|
||||
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
|
||||
payload = Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
image=source_url,
|
||||
seed=seed,
|
||||
guidance_scale=guidance_scale,
|
||||
watermark=watermark,
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
|
||||
|
||||
class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -1101,7 +1007,6 @@ class ByteDanceExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ByteDanceImageNode,
|
||||
ByteDanceImageEditNode,
|
||||
ByteDanceSeedreamNode,
|
||||
ByteDanceTextToVideoNode,
|
||||
ByteDanceImageToVideoNode,
|
||||
|
||||
889
comfy_api_nodes/nodes_magnific.py
Normal file
889
comfy_api_nodes/nodes_magnific.py
Normal file
@ -0,0 +1,889 @@
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.magnific import (
|
||||
ImageRelightAdvancedSettingsRequest,
|
||||
ImageRelightRequest,
|
||||
ImageSkinEnhancerCreativeRequest,
|
||||
ImageSkinEnhancerFaithfulRequest,
|
||||
ImageSkinEnhancerFlexibleRequest,
|
||||
ImageStyleTransferRequest,
|
||||
ImageUpscalerCreativeRequest,
|
||||
ImageUpscalerPrecisionV2Request,
|
||||
InputAdvancedSettings,
|
||||
InputPortraitMode,
|
||||
InputSkinEnhancerMode,
|
||||
TaskResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor,
|
||||
get_image_dimensions,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
)
|
||||
|
||||
|
||||
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerCreativeNode",
|
||||
display_name="Magnific Image Upscale (Creative)",
|
||||
category="api node/image/Magnific",
|
||||
description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
|
||||
"Maximum output: 25.3 megapixels.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]),
|
||||
IO.Combo.Input(
|
||||
"optimized_for",
|
||||
options=[
|
||||
"standard",
|
||||
"soft_portraits",
|
||||
"hard_portraits",
|
||||
"art_n_illustration",
|
||||
"videogame_assets",
|
||||
"nature_n_landscapes",
|
||||
"films_n_photography",
|
||||
"3d_renders",
|
||||
"science_fiction_n_horror",
|
||||
],
|
||||
),
|
||||
IO.Int.Input("creativity", min=-10, max=10, default=0, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Int.Input(
|
||||
"hdr",
|
||||
min=-10,
|
||||
max=10,
|
||||
default=0,
|
||||
tooltip="The level of definition and detail.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"resemblance",
|
||||
min=-10,
|
||||
max=10,
|
||||
default=0,
|
||||
tooltip="The level of resemblance to the original image.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"fractality",
|
||||
min=-10,
|
||||
max=10,
|
||||
default=0,
|
||||
tooltip="The strength of the prompt and intricacy per square pixel.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"engine",
|
||||
options=["automatic", "magnific_illusio", "magnific_sharpy", "magnific_sparkle"],
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
tooltip="Automatically downscale input image if output would exceed maximum pixel limit.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
scale_factor: str,
|
||||
optimized_for: str,
|
||||
creativity: int,
|
||||
hdr: int,
|
||||
resemblance: int,
|
||||
fractality: int,
|
||||
engine: str,
|
||||
auto_downscale: bool,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False)
|
||||
validate_image_dimensions(image, min_height=160, min_width=160)
|
||||
|
||||
max_output_pixels = 25_300_000
|
||||
height, width = get_image_dimensions(image)
|
||||
requested_scale = int(scale_factor.rstrip("x"))
|
||||
output_pixels = height * width * requested_scale * requested_scale
|
||||
|
||||
if output_pixels > max_output_pixels:
|
||||
if auto_downscale:
|
||||
# Find optimal scale factor that doesn't require >2x downscale.
|
||||
# Server upscales in 2x steps, so aggressive downscaling degrades quality.
|
||||
input_pixels = width * height
|
||||
scale = 2
|
||||
max_input_pixels = max_output_pixels // 4
|
||||
for candidate in [16, 8, 4, 2]:
|
||||
if candidate > requested_scale:
|
||||
continue
|
||||
scale_output_pixels = input_pixels * candidate * candidate
|
||||
if scale_output_pixels <= max_output_pixels:
|
||||
scale = candidate
|
||||
max_input_pixels = None
|
||||
break
|
||||
downscale_ratio = math.sqrt(scale_output_pixels / max_output_pixels)
|
||||
if downscale_ratio <= 2.0:
|
||||
scale = candidate
|
||||
max_input_pixels = max_output_pixels // (candidate * candidate)
|
||||
break
|
||||
|
||||
if max_input_pixels is not None:
|
||||
image = downscale_image_tensor(image, total_pixels=max_input_pixels)
|
||||
scale_factor = f"{scale}x"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Output size ({width * requested_scale}x{height * requested_scale} = {output_pixels:,} pixels) "
|
||||
f"exceeds maximum allowed size of {max_output_pixels:,} pixels. "
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
|
||||
response_model=TaskResponse,
|
||||
data=ImageUpscalerCreativeRequest(
|
||||
image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0],
|
||||
scale_factor=scale_factor,
|
||||
optimized_for=optimized_for,
|
||||
creativity=creativity,
|
||||
hdr=hdr,
|
||||
resemblance=resemblance,
|
||||
fractality=fractality,
|
||||
engine=engine,
|
||||
prompt=prompt if prompt else None,
|
||||
),
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
|
||||
class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerPreciseV2Node",
|
||||
display_name="Magnific Image Upscale (Precise V2)",
|
||||
category="api node/image/Magnific",
|
||||
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
|
||||
"Maximum output: 10060×10060 pixels.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]),
|
||||
IO.Combo.Input(
|
||||
"flavor",
|
||||
options=["sublime", "photo", "photo_denoiser"],
|
||||
tooltip="Processing style: "
|
||||
"sublime for general use, photo for photographs, photo_denoiser for noisy photos.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"sharpen",
|
||||
min=0,
|
||||
max=100,
|
||||
default=7,
|
||||
tooltip="Image sharpness intensity. Higher values increase edge definition and clarity.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"smart_grain",
|
||||
min=0,
|
||||
max=100,
|
||||
default=7,
|
||||
tooltip="Intelligent grain/texture enhancement to prevent the image from "
|
||||
"looking too smooth or artificial.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"ultra_detail",
|
||||
min=0,
|
||||
max=100,
|
||||
default=30,
|
||||
tooltip="Controls fine detail, textures, and micro-details added during upscaling.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
tooltip="Automatically downscale input image if output would exceed maximum resolution.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
scale_factor: str,
|
||||
flavor: str,
|
||||
sharpen: int,
|
||||
smart_grain: int,
|
||||
ultra_detail: int,
|
||||
auto_downscale: bool,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False)
|
||||
validate_image_dimensions(image, min_height=160, min_width=160)
|
||||
|
||||
max_output_dimension = 10060
|
||||
height, width = get_image_dimensions(image)
|
||||
requested_scale = int(scale_factor.strip("x"))
|
||||
output_width = width * requested_scale
|
||||
output_height = height * requested_scale
|
||||
|
||||
if output_width > max_output_dimension or output_height > max_output_dimension:
|
||||
if auto_downscale:
|
||||
# Find optimal scale factor that doesn't require >2x downscale.
|
||||
# Server upscales in 2x steps, so aggressive downscaling degrades quality.
|
||||
max_dim = max(width, height)
|
||||
scale = 2
|
||||
max_input_dim = max_output_dimension // 2
|
||||
scale_ratio = max_input_dim / max_dim
|
||||
max_input_pixels = int(width * height * scale_ratio * scale_ratio)
|
||||
for candidate in [16, 8, 4, 2]:
|
||||
if candidate > requested_scale:
|
||||
continue
|
||||
output_dim = max_dim * candidate
|
||||
if output_dim <= max_output_dimension:
|
||||
scale = candidate
|
||||
max_input_pixels = None
|
||||
break
|
||||
downscale_ratio = output_dim / max_output_dimension
|
||||
if downscale_ratio <= 2.0:
|
||||
scale = candidate
|
||||
max_input_dim = max_output_dimension // candidate
|
||||
scale_ratio = max_input_dim / max_dim
|
||||
max_input_pixels = int(width * height * scale_ratio * scale_ratio)
|
||||
break
|
||||
|
||||
if max_input_pixels is not None:
|
||||
image = downscale_image_tensor(image, total_pixels=max_input_pixels)
|
||||
requested_scale = scale
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Output dimensions ({output_width}x{output_height}) exceed maximum allowed "
|
||||
f"resolution of {max_output_dimension}x{max_output_dimension} pixels. "
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
|
||||
response_model=TaskResponse,
|
||||
data=ImageUpscalerPrecisionV2Request(
|
||||
image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0],
|
||||
scale_factor=requested_scale,
|
||||
flavor=flavor,
|
||||
sharpen=sharpen,
|
||||
smart_grain=smart_grain,
|
||||
ultra_detail=ultra_detail,
|
||||
),
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
|
||||
class MagnificImageStyleTransferNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageStyleTransferNode",
|
||||
display_name="Magnific Image Style Transfer",
|
||||
category="api node/image/Magnific",
|
||||
description="Transfer the style from a reference image to your input image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to apply style transfer to."),
|
||||
IO.Image.Input("reference_image", tooltip="The reference image to extract style from."),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Int.Input(
|
||||
"style_strength",
|
||||
min=0,
|
||||
max=100,
|
||||
default=100,
|
||||
tooltip="Percentage of style strength.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"structure_strength",
|
||||
min=0,
|
||||
max=100,
|
||||
default=50,
|
||||
tooltip="Maintains the structure of the original image.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"flavor",
|
||||
options=["faithful", "gen_z", "psychedelia", "detaily", "clear", "donotstyle", "donotstyle_sharp"],
|
||||
tooltip="Style transfer flavor.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"engine",
|
||||
options=[
|
||||
"balanced",
|
||||
"definio",
|
||||
"illusio",
|
||||
"3d_cartoon",
|
||||
"colorful_anime",
|
||||
"caricature",
|
||||
"real",
|
||||
"super_real",
|
||||
"softy",
|
||||
],
|
||||
tooltip="Processing engine selection.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"portrait_mode",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"enabled",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"portrait_style",
|
||||
options=["standard", "pop", "super_pop"],
|
||||
tooltip="Visual style applied to portrait images.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"portrait_beautifier",
|
||||
options=["none", "beautify_face", "beautify_face_max"],
|
||||
tooltip="Facial beautification intensity on portraits.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Enable portrait mode for facial enhancements.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"fixed_generation",
|
||||
default=True,
|
||||
tooltip="When disabled, expect each generation to introduce a degree of randomness, "
|
||||
"leading to more diverse outcomes.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.11}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
reference_image: Input.Image,
|
||||
prompt: str,
|
||||
style_strength: int,
|
||||
structure_strength: int,
|
||||
flavor: str,
|
||||
engine: str,
|
||||
portrait_mode: InputPortraitMode,
|
||||
fixed_generation: bool,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
if get_number_of_images(reference_image) != 1:
|
||||
raise ValueError("Exactly one reference image is required.")
|
||||
validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False)
|
||||
validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False)
|
||||
validate_image_dimensions(image, min_height=160, min_width=160)
|
||||
validate_image_dimensions(reference_image, min_height=160, min_width=160)
|
||||
|
||||
is_portrait = portrait_mode["portrait_mode"] == "enabled"
|
||||
portrait_style = portrait_mode.get("portrait_style", "standard")
|
||||
portrait_beautifier = portrait_mode.get("portrait_beautifier", "none")
|
||||
|
||||
uploaded_urls = await upload_images_to_comfyapi(cls, [image, reference_image], max_images=2)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-style-transfer", method="POST"),
|
||||
response_model=TaskResponse,
|
||||
data=ImageStyleTransferRequest(
|
||||
image=uploaded_urls[0],
|
||||
reference_image=uploaded_urls[1],
|
||||
prompt=prompt if prompt else None,
|
||||
style_strength=style_strength,
|
||||
structure_strength=structure_strength,
|
||||
is_portrait=is_portrait,
|
||||
portrait_style=portrait_style if is_portrait else None,
|
||||
portrait_beautifier=portrait_beautifier if is_portrait and portrait_beautifier != "none" else None,
|
||||
flavor=flavor,
|
||||
engine=engine,
|
||||
fixed_generation=fixed_generation,
|
||||
),
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-style-transfer/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
|
||||
class MagnificImageRelightNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageRelightNode",
|
||||
display_name="Magnific Image Relight",
|
||||
category="api node/image/Magnific",
|
||||
description="Relight an image with lighting adjustments and optional reference-based light transfer.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to relight."),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Descriptive guidance for lighting. Supports emphasis notation (1-1.4).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"light_transfer_strength",
|
||||
min=0,
|
||||
max=100,
|
||||
default=100,
|
||||
tooltip="Intensity of light transfer application.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style",
|
||||
options=[
|
||||
"standard",
|
||||
"darker_but_realistic",
|
||||
"clean",
|
||||
"smooth",
|
||||
"brighter",
|
||||
"contrasted_n_hdr",
|
||||
"just_composition",
|
||||
],
|
||||
tooltip="Stylistic output preference.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"interpolate_from_original",
|
||||
default=False,
|
||||
tooltip="Restricts generation freedom to match original more closely.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"change_background",
|
||||
default=True,
|
||||
tooltip="Modifies background based on prompt/reference.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"preserve_details",
|
||||
default=True,
|
||||
tooltip="Maintains texture and fine details from original.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"advanced_settings",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"enabled",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"whites",
|
||||
min=0,
|
||||
max=100,
|
||||
default=50,
|
||||
tooltip="Adjusts the brightest tones in the image.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"blacks",
|
||||
min=0,
|
||||
max=100,
|
||||
default=50,
|
||||
tooltip="Adjusts the darkest tones in the image.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"brightness",
|
||||
min=0,
|
||||
max=100,
|
||||
default=50,
|
||||
tooltip="Overall brightness adjustment.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"contrast",
|
||||
min=0,
|
||||
max=100,
|
||||
default=50,
|
||||
tooltip="Contrast adjustment.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"saturation",
|
||||
min=0,
|
||||
max=100,
|
||||
default=50,
|
||||
tooltip="Color saturation adjustment.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"engine",
|
||||
options=[
|
||||
"automatic",
|
||||
"balanced",
|
||||
"cool",
|
||||
"real",
|
||||
"illusio",
|
||||
"fairy",
|
||||
"colorful_anime",
|
||||
"hard_transform",
|
||||
"softy",
|
||||
],
|
||||
tooltip="Processing engine selection.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"transfer_light_a",
|
||||
options=["automatic", "low", "medium", "normal", "high", "high_on_faces"],
|
||||
tooltip="The intensity of light transfer.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"transfer_light_b",
|
||||
options=[
|
||||
"automatic",
|
||||
"composition",
|
||||
"straight",
|
||||
"smooth_in",
|
||||
"smooth_out",
|
||||
"smooth_both",
|
||||
"reverse_both",
|
||||
"soft_in",
|
||||
"soft_out",
|
||||
"soft_mid",
|
||||
# "strong_mid", # Commented out because requests fail when this is set.
|
||||
"style_shift",
|
||||
"strong_shift",
|
||||
],
|
||||
tooltip="Also modifies light transfer intensity. "
|
||||
"Can be combined with the previous control for varied effects.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"fixed_generation",
|
||||
default=True,
|
||||
tooltip="Ensures consistent output with the same settings.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Fine-tuning options for advanced lighting control.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"reference_image",
|
||||
optional=True,
|
||||
tooltip="Optional reference image to transfer lighting from.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.11}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
light_transfer_strength: int,
|
||||
style: str,
|
||||
interpolate_from_original: bool,
|
||||
change_background: bool,
|
||||
preserve_details: bool,
|
||||
advanced_settings: InputAdvancedSettings,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
if reference_image is not None and get_number_of_images(reference_image) != 1:
|
||||
raise ValueError("Exactly one reference image is required.")
|
||||
validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False)
|
||||
validate_image_dimensions(image, min_height=160, min_width=160)
|
||||
if reference_image is not None:
|
||||
validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False)
|
||||
validate_image_dimensions(reference_image, min_height=160, min_width=160)
|
||||
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||
reference_url = None
|
||||
if reference_image is not None:
|
||||
reference_url = (await upload_images_to_comfyapi(cls, reference_image, max_images=1))[0]
|
||||
|
||||
adv_settings = None
|
||||
if advanced_settings["advanced_settings"] == "enabled":
|
||||
adv_settings = ImageRelightAdvancedSettingsRequest(
|
||||
whites=advanced_settings["whites"],
|
||||
blacks=advanced_settings["blacks"],
|
||||
brightness=advanced_settings["brightness"],
|
||||
contrast=advanced_settings["contrast"],
|
||||
saturation=advanced_settings["saturation"],
|
||||
engine=advanced_settings["engine"],
|
||||
transfer_light_a=advanced_settings["transfer_light_a"],
|
||||
transfer_light_b=advanced_settings["transfer_light_b"],
|
||||
fixed_generation=advanced_settings["fixed_generation"],
|
||||
)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-relight", method="POST"),
|
||||
response_model=TaskResponse,
|
||||
data=ImageRelightRequest(
|
||||
image=image_url,
|
||||
prompt=prompt if prompt else None,
|
||||
transfer_light_from_reference_image=reference_url,
|
||||
light_transfer_strength=light_transfer_strength,
|
||||
interpolate_from_original=interpolate_from_original,
|
||||
change_background=change_background,
|
||||
style=style,
|
||||
preserve_details=preserve_details,
|
||||
advanced_settings=adv_settings,
|
||||
),
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-relight/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
|
||||
class MagnificImageSkinEnhancerNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageSkinEnhancerNode",
|
||||
display_name="Magnific Image Skin Enhancer",
|
||||
category="api node/image/Magnific",
|
||||
description="Skin enhancement for portraits with multiple processing modes.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The portrait image to enhance."),
|
||||
IO.Int.Input(
|
||||
"sharpen",
|
||||
min=0,
|
||||
max=100,
|
||||
default=0,
|
||||
tooltip="Sharpening intensity level.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"smart_grain",
|
||||
min=0,
|
||||
max=100,
|
||||
default=2,
|
||||
tooltip="Smart grain intensity level.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"mode",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("creative", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"faithful",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"skin_detail",
|
||||
min=0,
|
||||
max=100,
|
||||
default=80,
|
||||
tooltip="Skin detail enhancement level.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"flexible",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"optimized_for",
|
||||
options=[
|
||||
"enhance_skin",
|
||||
"improve_lighting",
|
||||
"enhance_everything",
|
||||
"transform_to_real",
|
||||
"no_make_up",
|
||||
],
|
||||
tooltip="Enhancement optimization target.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Processing mode: creative for artistic enhancement, "
|
||||
"faithful for preserving original appearance, "
|
||||
"flexible for targeted optimization.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
|
||||
expr="""
|
||||
(
|
||||
$rates := {"creative": 0.29, "faithful": 0.37, "flexible": 0.45};
|
||||
{"type":"usd","usd": $lookup($rates, widgets.mode)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
sharpen: int,
|
||||
smart_grain: int,
|
||||
mode: InputSkinEnhancerMode,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False)
|
||||
validate_image_dimensions(image, min_height=160, min_width=160)
|
||||
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=4096 * 4096))[0]
|
||||
selected_mode = mode["mode"]
|
||||
|
||||
if selected_mode == "creative":
|
||||
endpoint = "creative"
|
||||
data = ImageSkinEnhancerCreativeRequest(
|
||||
image=image_url,
|
||||
sharpen=sharpen,
|
||||
smart_grain=smart_grain,
|
||||
)
|
||||
elif selected_mode == "faithful":
|
||||
endpoint = "faithful"
|
||||
data = ImageSkinEnhancerFaithfulRequest(
|
||||
image=image_url,
|
||||
sharpen=sharpen,
|
||||
smart_grain=smart_grain,
|
||||
skin_detail=mode["skin_detail"],
|
||||
)
|
||||
else: # flexible
|
||||
endpoint = "flexible"
|
||||
data = ImageSkinEnhancerFlexibleRequest(
|
||||
image=image_url,
|
||||
sharpen=sharpen,
|
||||
smart_grain=smart_grain,
|
||||
optimized_for=mode["optimized_for"],
|
||||
)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{endpoint}", method="POST"),
|
||||
response_model=TaskResponse,
|
||||
data=data,
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
|
||||
class MagnificExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
# MagnificImageUpscalerCreativeNode,
|
||||
# MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageStyleTransferNode,
|
||||
MagnificImageRelightNode,
|
||||
MagnificImageSkinEnhancerNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MagnificExtension:
|
||||
return MagnificExtension()
|
||||
@ -56,15 +56,14 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
@ -79,13 +78,14 @@ def tensor_to_bytesio(
|
||||
return img_binary
|
||||
|
||||
|
||||
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
def tensor_to_pil(image: torch.Tensor, total_pixels: int | None = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
|
||||
if total_pixels is not None:
|
||||
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
@ -93,14 +93,14 @@ def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
@ -161,14 +161,14 @@ def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
|
||||
@ -49,7 +49,7 @@ async def upload_images_to_comfyapi(
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
show_batch_index: bool = True,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
|
||||
316
comfy_execution/cache_provider.py
Normal file
316
comfy_execution/cache_provider.py
Normal file
@ -0,0 +1,316 @@
|
||||
"""
|
||||
External Cache Provider API for distributed caching.
|
||||
|
||||
This module provides a public API for external cache providers, enabling
|
||||
distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods).
|
||||
|
||||
Example usage:
|
||||
from comfy_execution.cache_provider import (
|
||||
CacheProvider, CacheContext, CacheValue, register_cache_provider
|
||||
)
|
||||
|
||||
class MyRedisProvider(CacheProvider):
|
||||
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
# Check Redis/GCS for cached result
|
||||
...
|
||||
|
||||
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
# Store to Redis/GCS (can be async internally)
|
||||
...
|
||||
|
||||
register_cache_provider(MyRedisProvider())
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Tuple, List
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import pickle
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data Classes
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class CacheContext:
|
||||
"""Context passed to provider methods."""
|
||||
prompt_id: str # Current prompt execution ID
|
||||
node_id: str # Node being cached
|
||||
class_type: str # Node class type (e.g., "KSampler")
|
||||
cache_key: Any # Raw cache key (frozenset structure)
|
||||
cache_key_bytes: bytes # SHA256 hash for external storage key
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheValue:
|
||||
"""
|
||||
Value stored/retrieved from external cache.
|
||||
|
||||
Note: UI data is intentionally excluded - it contains pod-local
|
||||
file paths that aren't portable across instances.
|
||||
"""
|
||||
outputs: list # The tensor/value outputs
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Provider Interface
|
||||
# ============================================================
|
||||
|
||||
class CacheProvider(ABC):
|
||||
"""
|
||||
Abstract base class for external cache providers.
|
||||
|
||||
Thread Safety:
|
||||
Providers may be called from multiple threads. Implementations
|
||||
must be thread-safe.
|
||||
|
||||
Error Handling:
|
||||
All methods are wrapped in try/except by the caller. Exceptions
|
||||
are logged but never propagate to break execution.
|
||||
|
||||
Performance Guidelines:
|
||||
- on_lookup: Should complete in <500ms (including network)
|
||||
- on_store: Can be async internally (fire-and-forget)
|
||||
- should_cache: Should be fast (<1ms), called frequently
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
"""
|
||||
Check external storage for cached result.
|
||||
|
||||
Called AFTER local cache miss (local-first for performance).
|
||||
|
||||
Returns:
|
||||
CacheValue if found externally, None otherwise.
|
||||
|
||||
Important:
|
||||
- Return None on any error (don't raise)
|
||||
- Validate data integrity before returning
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
"""
|
||||
Store value to external cache.
|
||||
|
||||
Called AFTER value is stored in local cache.
|
||||
|
||||
Important:
|
||||
- Can be fire-and-forget (async internally)
|
||||
- Should never block execution
|
||||
- Handle serialization failures gracefully
|
||||
"""
|
||||
pass
|
||||
|
||||
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
|
||||
"""
|
||||
Filter which nodes should be externally cached.
|
||||
|
||||
Called before on_lookup (value=None) and on_store (value provided).
|
||||
Return False to skip external caching for this node.
|
||||
|
||||
Common filters:
|
||||
- By class_type: Only expensive nodes (KSampler, VAEDecode)
|
||||
- By size: Skip small values (< 1MB)
|
||||
|
||||
Default: Returns True (cache everything).
|
||||
"""
|
||||
return True
|
||||
|
||||
def on_prompt_start(self, prompt_id: str) -> None:
|
||||
"""Called when prompt execution begins. Optional."""
|
||||
pass
|
||||
|
||||
def on_prompt_end(self, prompt_id: str) -> None:
|
||||
"""Called when prompt execution ends. Optional."""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Provider Registry
|
||||
# ============================================================
|
||||
|
||||
_providers: List[CacheProvider] = []
|
||||
_providers_lock = threading.Lock()
|
||||
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
|
||||
|
||||
|
||||
def register_cache_provider(provider: CacheProvider) -> None:
|
||||
"""
|
||||
Register an external cache provider.
|
||||
|
||||
Providers are called in registration order. First provider to return
|
||||
a result from on_lookup wins.
|
||||
"""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
if provider in _providers:
|
||||
logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||
return
|
||||
_providers.append(provider)
|
||||
_providers_snapshot = None # Invalidate cache
|
||||
logger.info(f"Registered cache provider: {provider.__class__.__name__}")
|
||||
|
||||
|
||||
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||
"""Remove a previously registered provider."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
try:
|
||||
_providers.remove(provider)
|
||||
_providers_snapshot = None
|
||||
logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
|
||||
except ValueError:
|
||||
logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
||||
|
||||
|
||||
def get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||
"""Get registered providers (cached for performance)."""
|
||||
global _providers_snapshot
|
||||
snapshot = _providers_snapshot
|
||||
if snapshot is not None:
|
||||
return snapshot
|
||||
with _providers_lock:
|
||||
if _providers_snapshot is not None:
|
||||
return _providers_snapshot
|
||||
_providers_snapshot = tuple(_providers)
|
||||
return _providers_snapshot
|
||||
|
||||
|
||||
def has_cache_providers() -> bool:
|
||||
"""Fast check if any providers registered (no lock)."""
|
||||
return bool(_providers)
|
||||
|
||||
|
||||
def clear_cache_providers() -> None:
|
||||
"""Remove all providers. Useful for testing."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
_providers.clear()
|
||||
_providers_snapshot = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Utilities
|
||||
# ============================================================
|
||||
|
||||
def _canonicalize(obj: Any) -> Any:
|
||||
"""
|
||||
Convert an object to a canonical, JSON-serializable form.
|
||||
|
||||
This ensures deterministic ordering regardless of Python's hash randomization,
|
||||
which is critical for cross-pod cache key consistency. Frozensets in particular
|
||||
have non-deterministic iteration order between Python sessions.
|
||||
"""
|
||||
if isinstance(obj, frozenset):
|
||||
# Sort frozenset items for deterministic ordering
|
||||
return ("__frozenset__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, tuple):
|
||||
return ("__tuple__", [_canonicalize(item) for item in obj])
|
||||
elif isinstance(obj, list):
|
||||
return [_canonicalize(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
return {str(k): _canonicalize(v) for k, v in sorted(obj.items())}
|
||||
elif isinstance(obj, (int, float, str, bool, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, bytes):
|
||||
return ("__bytes__", obj.hex())
|
||||
elif hasattr(obj, 'value'):
|
||||
# Handle Unhashable class from ComfyUI
|
||||
return ("__unhashable__", _canonicalize(getattr(obj, 'value', None)))
|
||||
else:
|
||||
# For other types, use repr as fallback
|
||||
return ("__repr__", repr(obj))
|
||||
|
||||
|
||||
def serialize_cache_key(cache_key: Any) -> bytes:
|
||||
"""
|
||||
Serialize cache key to bytes for external storage.
|
||||
|
||||
Returns SHA256 hash suitable for Redis/database keys.
|
||||
|
||||
Note: Uses canonicalize + JSON serialization instead of pickle because
|
||||
pickle is NOT deterministic across Python sessions due to hash randomization
|
||||
affecting frozenset iteration order. This is critical for distributed caching
|
||||
where different pods need to compute the same hash for identical inputs.
|
||||
"""
|
||||
try:
|
||||
canonical = _canonicalize(cache_key)
|
||||
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
||||
return hashlib.sha256(json_str.encode('utf-8')).digest()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to serialize cache key: {e}")
|
||||
# Fallback to pickle (non-deterministic but better than nothing)
|
||||
try:
|
||||
serialized = pickle.dumps(cache_key, protocol=4)
|
||||
return hashlib.sha256(serialized).digest()
|
||||
except Exception:
|
||||
return hashlib.sha256(str(id(cache_key)).encode()).digest()
|
||||
|
||||
|
||||
def contains_nan(obj: Any) -> bool:
|
||||
"""
|
||||
Check if cache key contains NaN (indicates uncacheable node).
|
||||
|
||||
NaN != NaN in Python, so local cache never hits. But serialized
|
||||
NaN would match, causing incorrect external hits. Must skip these.
|
||||
"""
|
||||
if isinstance(obj, float):
|
||||
try:
|
||||
return math.isnan(obj)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
if hasattr(obj, 'value'): # Unhashable class
|
||||
val = getattr(obj, 'value', None)
|
||||
if isinstance(val, float):
|
||||
try:
|
||||
return math.isnan(val)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||
return any(contains_nan(item) for item in obj)
|
||||
if isinstance(obj, dict):
|
||||
return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
|
||||
return False
|
||||
|
||||
|
||||
def estimate_value_size(value: CacheValue) -> int:
|
||||
"""Estimate serialized size in bytes. Useful for size-based filtering."""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
return 0
|
||||
|
||||
total = 0
|
||||
|
||||
def estimate(obj):
|
||||
nonlocal total
|
||||
if isinstance(obj, torch.Tensor):
|
||||
total += obj.numel() * obj.element_size()
|
||||
elif isinstance(obj, dict):
|
||||
for v in obj.values():
|
||||
estimate(v)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
estimate(item)
|
||||
|
||||
for output in value.outputs:
|
||||
estimate(output)
|
||||
return total
|
||||
@ -155,6 +155,10 @@ class BasicCache:
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
# External cache provider support
|
||||
self._is_subcache = False
|
||||
self._current_prompt_id = ''
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
@ -201,20 +205,123 @@ class BasicCache:
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
# Notify external providers
|
||||
self._notify_providers_store(node_id, cache_key, value)
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
|
||||
# Check local cache first (fast path)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
else:
|
||||
|
||||
# Check external providers on local miss
|
||||
external_result = self._check_providers_lookup(node_id, cache_key)
|
||||
if external_result is not None:
|
||||
self.cache[cache_key] = external_result # Warm local cache
|
||||
return external_result
|
||||
|
||||
return None
|
||||
|
||||
def _notify_providers_store(self, node_id, cache_key, value):
|
||||
"""Notify external providers of cache store."""
|
||||
from comfy_execution.cache_provider import (
|
||||
has_cache_providers, get_cache_providers,
|
||||
CacheContext, CacheValue,
|
||||
serialize_cache_key, contains_nan, logger
|
||||
)
|
||||
|
||||
# Fast exit conditions
|
||||
if self._is_subcache:
|
||||
return
|
||||
if not has_cache_providers():
|
||||
return
|
||||
if not self._is_cacheable_value(value):
|
||||
return
|
||||
if contains_nan(cache_key):
|
||||
return
|
||||
|
||||
context = CacheContext(
|
||||
prompt_id=self._current_prompt_id,
|
||||
node_id=node_id,
|
||||
class_type=self._get_class_type(node_id),
|
||||
cache_key=cache_key,
|
||||
cache_key_bytes=serialize_cache_key(cache_key)
|
||||
)
|
||||
cache_value = CacheValue(outputs=value.outputs)
|
||||
|
||||
for provider in get_cache_providers():
|
||||
try:
|
||||
if provider.should_cache(context, cache_value):
|
||||
provider.on_store(context, cache_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||
|
||||
def _check_providers_lookup(self, node_id, cache_key):
|
||||
"""Check external providers for cached result."""
|
||||
from comfy_execution.cache_provider import (
|
||||
has_cache_providers, get_cache_providers,
|
||||
CacheContext, CacheValue,
|
||||
serialize_cache_key, contains_nan, logger
|
||||
)
|
||||
|
||||
if self._is_subcache:
|
||||
return None
|
||||
if not has_cache_providers():
|
||||
return None
|
||||
if contains_nan(cache_key):
|
||||
return None
|
||||
|
||||
context = CacheContext(
|
||||
prompt_id=self._current_prompt_id,
|
||||
node_id=node_id,
|
||||
class_type=self._get_class_type(node_id),
|
||||
cache_key=cache_key,
|
||||
cache_key_bytes=serialize_cache_key(cache_key)
|
||||
)
|
||||
|
||||
for provider in get_cache_providers():
|
||||
try:
|
||||
if not provider.should_cache(context):
|
||||
continue
|
||||
result = provider.on_lookup(context)
|
||||
if result is not None:
|
||||
if not isinstance(result, CacheValue):
|
||||
logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
||||
continue
|
||||
if not isinstance(result.outputs, (list, tuple)):
|
||||
logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||
continue
|
||||
# Import CacheEntry here to avoid circular import at module level
|
||||
from execution import CacheEntry
|
||||
return CacheEntry(ui={}, outputs=list(result.outputs))
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _is_cacheable_value(self, value):
|
||||
"""Check if value is a CacheEntry (not objects cache)."""
|
||||
return hasattr(value, 'outputs') and hasattr(value, 'ui')
|
||||
|
||||
def _get_class_type(self, node_id):
|
||||
"""Get class_type for a node."""
|
||||
if not self.initialized or not self.dynprompt:
|
||||
return ''
|
||||
try:
|
||||
return self.dynprompt.get_node(node_id).get('class_type', '')
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
subcache._is_subcache = True # Mark as subcache - excludes from external caching
|
||||
subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID
|
||||
self.subcaches[subcache_key] = subcache
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
@ -701,7 +701,14 @@ class Noise_EmptyNoise:
|
||||
|
||||
def generate_noise(self, input_latent):
|
||||
latent_image = input_latent["samples"]
|
||||
return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
if latent_image.is_nested:
|
||||
tensors = latent_image.unbind()
|
||||
zeros = []
|
||||
for t in tensors:
|
||||
zeros.append(torch.zeros(t.shape, dtype=t.dtype, layout=t.layout, device="cpu"))
|
||||
return comfy.nested_tensor.NestedTensor(zeros)
|
||||
else:
|
||||
return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
|
||||
|
||||
class Noise_RandomNoise:
|
||||
|
||||
@ -223,11 +223,24 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return frame_idx, latent_idx
|
||||
|
||||
@classmethod
|
||||
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
|
||||
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1):
|
||||
keyframe_idxs, _ = get_keyframe_idxs(cond)
|
||||
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
|
||||
pixel_coords[:, 0] += frame_idx
|
||||
|
||||
# The following adjusts keyframe end positions for small grid IC-LoRA.
|
||||
# After dilation, the small grid has the same size and position as the large grid,
|
||||
# but each token encodes a larger image patch. We adjust the end position (not start)
|
||||
# so that RoPE represents the correct middle point of each token.
|
||||
# keyframe_idxs dims: (batch, spatial_dim [t,h,w], token_id, [start, end])
|
||||
# We only adjust h,w (not t) in dim 1, and only end (not start) in dim 3.
|
||||
spatial_end_offset = (latent_downscale_factor - 1) * torch.tensor(
|
||||
scale_factors[1:],
|
||||
device=pixel_coords.device,
|
||||
).view(1, -1, 1, 1)
|
||||
pixel_coords[:, 1:, :, 1:] += spatial_end_offset.to(pixel_coords.dtype)
|
||||
|
||||
if keyframe_idxs is None:
|
||||
keyframe_idxs = pixel_coords
|
||||
else:
|
||||
@ -235,12 +248,12 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||
|
||||
@classmethod
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1):
|
||||
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
|
||||
raise ValueError("Adding guide to a combined AV latent is not supported.")
|
||||
|
||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
|
||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
|
||||
|
||||
if guide_mask is not None:
|
||||
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
|
||||
|
||||
137
execution.py
137
execution.py
@ -669,6 +669,22 @@ class PromptExecutor:
|
||||
}
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
|
||||
"""Notify external cache providers of prompt lifecycle events."""
|
||||
from comfy_execution.cache_provider import has_cache_providers, get_cache_providers, logger
|
||||
|
||||
if not has_cache_providers():
|
||||
return
|
||||
|
||||
for provider in get_cache_providers():
|
||||
try:
|
||||
if event == "start":
|
||||
provider.on_prompt_start(prompt_id)
|
||||
elif event == "end":
|
||||
provider.on_prompt_end(prompt_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
@ -685,66 +701,77 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
# Set prompt ID on caches for external provider integration
|
||||
for cache in self.caches.all:
|
||||
cache._current_prompt_id = prompt_id
|
||||
|
||||
cached_nodes = []
|
||||
for node_id in prompt:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
# Notify external cache providers of prompt start
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
cached_nodes = []
|
||||
for node_id in prompt:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.PENDING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for node_id, ui_info in ui_node_outputs.items():
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.PENDING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for node_id, ui_info in ui_node_outputs.items():
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
finally:
|
||||
# Notify external cache providers of prompt end
|
||||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
|
||||
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.37.11
|
||||
comfyui-workflow-templates==0.8.15
|
||||
comfyui-workflow-templates==0.8.24
|
||||
comfyui-embedded-docs==0.4.0
|
||||
torch
|
||||
torchsde
|
||||
|
||||
370
tests-unit/execution_test/test_cache_provider.py
Normal file
370
tests-unit/execution_test/test_cache_provider.py
Normal file
@ -0,0 +1,370 @@
|
||||
"""Tests for external cache provider API."""
|
||||
|
||||
import importlib.util
|
||||
import pytest
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _torch_available() -> bool:
|
||||
"""Check if PyTorch is available."""
|
||||
return importlib.util.find_spec("torch") is not None
|
||||
|
||||
|
||||
from comfy_execution.cache_provider import (
|
||||
CacheProvider,
|
||||
CacheContext,
|
||||
CacheValue,
|
||||
register_cache_provider,
|
||||
unregister_cache_provider,
|
||||
get_cache_providers,
|
||||
has_cache_providers,
|
||||
clear_cache_providers,
|
||||
serialize_cache_key,
|
||||
contains_nan,
|
||||
estimate_value_size,
|
||||
_canonicalize,
|
||||
)
|
||||
|
||||
|
||||
class TestCanonicalize:
|
||||
"""Test _canonicalize function for deterministic ordering."""
|
||||
|
||||
def test_frozenset_ordering_is_deterministic(self):
|
||||
"""Frozensets should produce consistent canonical form regardless of iteration order."""
|
||||
# Create two frozensets with same content
|
||||
fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
|
||||
fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
|
||||
|
||||
result1 = _canonicalize(fs1)
|
||||
result2 = _canonicalize(fs2)
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_nested_frozenset_ordering(self):
|
||||
"""Nested frozensets should also be deterministically ordered."""
|
||||
inner1 = frozenset([1, 2, 3])
|
||||
inner2 = frozenset([3, 2, 1])
|
||||
|
||||
fs1 = frozenset([("key", inner1)])
|
||||
fs2 = frozenset([("key", inner2)])
|
||||
|
||||
result1 = _canonicalize(fs1)
|
||||
result2 = _canonicalize(fs2)
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_dict_ordering(self):
|
||||
"""Dicts should be sorted by key."""
|
||||
d1 = {"z": 1, "a": 2, "m": 3}
|
||||
d2 = {"a": 2, "m": 3, "z": 1}
|
||||
|
||||
result1 = _canonicalize(d1)
|
||||
result2 = _canonicalize(d2)
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_tuple_preserved(self):
|
||||
"""Tuples should be marked and preserved."""
|
||||
t = (1, 2, 3)
|
||||
result = _canonicalize(t)
|
||||
|
||||
assert result[0] == "__tuple__"
|
||||
assert result[1] == [1, 2, 3]
|
||||
|
||||
def test_list_preserved(self):
|
||||
"""Lists should be recursively canonicalized."""
|
||||
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
|
||||
result = _canonicalize(lst)
|
||||
|
||||
# First element should be dict with sorted keys
|
||||
assert result[0] == {"a": 1, "b": 2}
|
||||
# Second element should be canonicalized frozenset
|
||||
assert result[1][0] == "__frozenset__"
|
||||
|
||||
def test_primitives_unchanged(self):
|
||||
"""Primitive types should pass through unchanged."""
|
||||
assert _canonicalize(42) == 42
|
||||
assert _canonicalize(3.14) == 3.14
|
||||
assert _canonicalize("hello") == "hello"
|
||||
assert _canonicalize(True) is True
|
||||
assert _canonicalize(None) is None
|
||||
|
||||
def test_bytes_converted(self):
|
||||
"""Bytes should be converted to hex string."""
|
||||
b = b"\x00\xff"
|
||||
result = _canonicalize(b)
|
||||
|
||||
assert result[0] == "__bytes__"
|
||||
assert result[1] == "00ff"
|
||||
|
||||
def test_set_ordering(self):
|
||||
"""Sets should be sorted like frozensets."""
|
||||
s1 = {3, 1, 2}
|
||||
s2 = {1, 2, 3}
|
||||
|
||||
result1 = _canonicalize(s1)
|
||||
result2 = _canonicalize(s2)
|
||||
|
||||
assert result1 == result2
|
||||
assert result1[0] == "__set__"
|
||||
|
||||
|
||||
class TestSerializeCacheKey:
|
||||
"""Test serialize_cache_key for deterministic hashing."""
|
||||
|
||||
def test_same_content_same_hash(self):
|
||||
"""Same content should produce same hash."""
|
||||
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||
|
||||
hash1 = serialize_cache_key(key1)
|
||||
hash2 = serialize_cache_key(key2)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_different_content_different_hash(self):
|
||||
"""Different content should produce different hash."""
|
||||
key1 = frozenset([("node_1", "value_a")])
|
||||
key2 = frozenset([("node_1", "value_b")])
|
||||
|
||||
hash1 = serialize_cache_key(key1)
|
||||
hash2 = serialize_cache_key(key2)
|
||||
|
||||
assert hash1 != hash2
|
||||
|
||||
def test_returns_bytes(self):
|
||||
"""Should return bytes (SHA256 digest)."""
|
||||
key = frozenset([("test", 123)])
|
||||
result = serialize_cache_key(key)
|
||||
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) == 32 # SHA256 produces 32 bytes
|
||||
|
||||
def test_complex_nested_structure(self):
|
||||
"""Complex nested structures should hash deterministically."""
|
||||
# Note: frozensets can only contain hashable types, so we use
|
||||
# nested frozensets of tuples to represent dict-like structures
|
||||
key = frozenset([
|
||||
("node_1", frozenset([
|
||||
("input_a", ("tuple", "value")),
|
||||
("input_b", frozenset([("nested", "dict")])),
|
||||
])),
|
||||
("node_2", frozenset([
|
||||
("param", 42),
|
||||
])),
|
||||
])
|
||||
|
||||
# Hash twice to verify determinism
|
||||
hash1 = serialize_cache_key(key)
|
||||
hash2 = serialize_cache_key(key)
|
||||
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_dict_in_cache_key(self):
|
||||
"""Dicts passed directly to serialize_cache_key should work."""
|
||||
# This tests the _canonicalize function's ability to handle dicts
|
||||
key = {"node_1": {"input": "value"}, "node_2": 42}
|
||||
|
||||
hash1 = serialize_cache_key(key)
|
||||
hash2 = serialize_cache_key(key)
|
||||
|
||||
assert hash1 == hash2
|
||||
assert isinstance(hash1, bytes)
|
||||
assert len(hash1) == 32
|
||||
|
||||
|
||||
class TestContainsNan:
|
||||
"""Test contains_nan utility function."""
|
||||
|
||||
def test_nan_float_detected(self):
|
||||
"""NaN floats should be detected."""
|
||||
assert contains_nan(float('nan')) is True
|
||||
|
||||
def test_regular_float_not_nan(self):
|
||||
"""Regular floats should not be detected as NaN."""
|
||||
assert contains_nan(3.14) is False
|
||||
assert contains_nan(0.0) is False
|
||||
assert contains_nan(-1.5) is False
|
||||
|
||||
def test_infinity_not_nan(self):
|
||||
"""Infinity is not NaN."""
|
||||
assert contains_nan(float('inf')) is False
|
||||
assert contains_nan(float('-inf')) is False
|
||||
|
||||
def test_nan_in_list(self):
|
||||
"""NaN in list should be detected."""
|
||||
assert contains_nan([1, 2, float('nan'), 4]) is True
|
||||
assert contains_nan([1, 2, 3, 4]) is False
|
||||
|
||||
def test_nan_in_tuple(self):
|
||||
"""NaN in tuple should be detected."""
|
||||
assert contains_nan((1, float('nan'))) is True
|
||||
assert contains_nan((1, 2, 3)) is False
|
||||
|
||||
def test_nan_in_frozenset(self):
|
||||
"""NaN in frozenset should be detected."""
|
||||
assert contains_nan(frozenset([1, float('nan')])) is True
|
||||
assert contains_nan(frozenset([1, 2, 3])) is False
|
||||
|
||||
def test_nan_in_dict_value(self):
|
||||
"""NaN in dict value should be detected."""
|
||||
assert contains_nan({"key": float('nan')}) is True
|
||||
assert contains_nan({"key": 42}) is False
|
||||
|
||||
def test_nan_in_nested_structure(self):
|
||||
"""NaN in deeply nested structure should be detected."""
|
||||
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
|
||||
assert contains_nan(nested) is True
|
||||
|
||||
def test_non_numeric_types(self):
|
||||
"""Non-numeric types should not be NaN."""
|
||||
assert contains_nan("string") is False
|
||||
assert contains_nan(None) is False
|
||||
assert contains_nan(True) is False
|
||||
|
||||
|
||||
class TestEstimateValueSize:
|
||||
"""Test estimate_value_size utility function."""
|
||||
|
||||
def test_empty_outputs(self):
|
||||
"""Empty outputs should have zero size."""
|
||||
value = CacheValue(outputs=[])
|
||||
assert estimate_value_size(value) == 0
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _torch_available(),
|
||||
reason="PyTorch not available"
|
||||
)
|
||||
def test_tensor_size_estimation(self):
|
||||
"""Tensor size should be estimated correctly."""
|
||||
import torch
|
||||
|
||||
# 1000 float32 elements = 4000 bytes
|
||||
tensor = torch.zeros(1000, dtype=torch.float32)
|
||||
value = CacheValue(outputs=[[tensor]])
|
||||
|
||||
size = estimate_value_size(value)
|
||||
assert size == 4000
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _torch_available(),
|
||||
reason="PyTorch not available"
|
||||
)
|
||||
def test_nested_tensor_in_dict(self):
|
||||
"""Tensors nested in dicts should be counted."""
|
||||
import torch
|
||||
|
||||
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||
value = CacheValue(outputs=[[{"samples": tensor}]])
|
||||
|
||||
size = estimate_value_size(value)
|
||||
assert size == 400
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
"""Test cache provider registration and retrieval."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear providers before each test."""
|
||||
clear_cache_providers()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clear providers after each test."""
|
||||
clear_cache_providers()
|
||||
|
||||
def test_register_provider(self):
|
||||
"""Provider should be registered successfully."""
|
||||
provider = MockCacheProvider()
|
||||
register_cache_provider(provider)
|
||||
|
||||
assert has_cache_providers() is True
|
||||
providers = get_cache_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0] is provider
|
||||
|
||||
def test_unregister_provider(self):
|
||||
"""Provider should be unregistered successfully."""
|
||||
provider = MockCacheProvider()
|
||||
register_cache_provider(provider)
|
||||
unregister_cache_provider(provider)
|
||||
|
||||
assert has_cache_providers() is False
|
||||
|
||||
def test_multiple_providers(self):
|
||||
"""Multiple providers can be registered."""
|
||||
provider1 = MockCacheProvider()
|
||||
provider2 = MockCacheProvider()
|
||||
|
||||
register_cache_provider(provider1)
|
||||
register_cache_provider(provider2)
|
||||
|
||||
providers = get_cache_providers()
|
||||
assert len(providers) == 2
|
||||
|
||||
def test_duplicate_registration_ignored(self):
|
||||
"""Registering same provider twice should be ignored."""
|
||||
provider = MockCacheProvider()
|
||||
|
||||
register_cache_provider(provider)
|
||||
register_cache_provider(provider) # Should be ignored
|
||||
|
||||
providers = get_cache_providers()
|
||||
assert len(providers) == 1
|
||||
|
||||
def test_clear_providers(self):
|
||||
"""clear_cache_providers should remove all providers."""
|
||||
provider1 = MockCacheProvider()
|
||||
provider2 = MockCacheProvider()
|
||||
|
||||
register_cache_provider(provider1)
|
||||
register_cache_provider(provider2)
|
||||
clear_cache_providers()
|
||||
|
||||
assert has_cache_providers() is False
|
||||
assert len(get_cache_providers()) == 0
|
||||
|
||||
|
||||
class TestCacheContext:
|
||||
"""Test CacheContext dataclass."""
|
||||
|
||||
def test_context_creation(self):
|
||||
"""CacheContext should be created with all fields."""
|
||||
context = CacheContext(
|
||||
prompt_id="prompt-123",
|
||||
node_id="node-456",
|
||||
class_type="KSampler",
|
||||
cache_key=frozenset([("test", "value")]),
|
||||
cache_key_bytes=b"hash_bytes",
|
||||
)
|
||||
|
||||
assert context.prompt_id == "prompt-123"
|
||||
assert context.node_id == "node-456"
|
||||
assert context.class_type == "KSampler"
|
||||
assert context.cache_key == frozenset([("test", "value")])
|
||||
assert context.cache_key_bytes == b"hash_bytes"
|
||||
|
||||
|
||||
class TestCacheValue:
|
||||
"""Test CacheValue dataclass."""
|
||||
|
||||
def test_value_creation(self):
|
||||
"""CacheValue should be created with outputs."""
|
||||
outputs = [[{"samples": "tensor_data"}]]
|
||||
value = CacheValue(outputs=outputs)
|
||||
|
||||
assert value.outputs == outputs
|
||||
|
||||
|
||||
class MockCacheProvider(CacheProvider):
|
||||
"""Mock cache provider for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.lookups = []
|
||||
self.stores = []
|
||||
|
||||
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
self.lookups.append(context)
|
||||
return None
|
||||
|
||||
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
self.stores.append((context, value))
|
||||
Loading…
Reference in New Issue
Block a user