Merge branch 'Comfy-Org:master' into qwen-image-vae

This commit is contained in:
brucew4yn3rp 2026-06-28 22:02:20 -04:00 committed by GitHub
commit 77cb84873a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1015 additions and 391 deletions

38
.github/workflows/ci-cursor-review.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: CI - Cursor Review
# Thin caller for the shared reusable cursor-review workflow in
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
# consolidation, prompts, extract/post/notify scripts) lives there as the
# single source of truth, so this repo only carries the repo-specific diff
# excludes.
on:
pull_request:
types: [labeled, unlabeled]
concurrency:
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
cancel-in-progress: true
jobs:
cursor-review:
if: github.event.label.name == 'cursor-review'
permissions:
contents: read
pull-requests: write
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
# from the same commit as the workflow definition.
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
with:
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
diff_excludes: >-
:!**/.claude/**
:!**/dist/**
:!**/vendor/**
:!**/*.generated.*
:!**/*.min.js
:!**/*.min.css
secrets:
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}

View File

@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
if (want_requant and len(fns) == 0 or update_weight): if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key) seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor): if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
else: else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed) y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0: if want_requant and len(fns) == 0:
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
if ts is None or bs is None: if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs} scales = {"scale": ts, "block_scale": bs}
elif module.quant_format == "int8_tensorwise":
scale = pop_scale("weight_scale")
if scale is None:
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
scales = {"scale": scale}
params_conf = layer_conf.get("params", {})
if not isinstance(params_conf, dict):
params_conf = {}
if layer_conf.get("convrot", params_conf.get("convrot", False)):
scales["convrot"] = True
scales["convrot_groupsize"] = int(
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
)
else: else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}") raise ValueError(f"Unsupported quantization format: {module.quant_format}")
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
quant_conf = {"format": module.quant_format} quant_conf = {"format": module.quant_format}
if getattr(module, '_full_precision_mm_config', False): if getattr(module, '_full_precision_mm_config', False):
quant_conf["full_precision_matrix_mult"] = True quant_conf["full_precision_matrix_mult"] = True
params = getattr(module.weight, "_params", None)
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
quant_conf["convrot"] = True
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
if extra_quant_conf: if extra_quant_conf:
quant_conf.update(extra_quant_conf) quant_conf.update(extra_quant_conf)
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8) sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False): def forward_comfy_cast_weights(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant) self,
input,
compute_dtype=None,
want_requant=False,
weight_only_quant=False,
):
if weight_only_quant:
weight, bias, offload_stream = cast_bias_weight(
self,
input=None,
dtype=self.weight.dtype,
device=input.device,
bias_dtype=input.dtype,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
)
weight = weight.to(dtype=input.dtype)
else:
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
)
x = self._forward(input, weight, bias) x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream) uncast_bias_weight(self, weight, bias, offload_stream)
return x return x
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
not getattr(self, 'comfy_force_cast_weights', False) and not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0 len(self.weight_function) == 0 and len(self.bias_function) == 0
) )
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
# Training path: quantized forward with compute_dtype backward via autograd function # Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized): if (input.requires_grad and _use_quantized and quantize_input):
weight, bias, offload_stream = cast_bias_weight( weight, bias, offload_stream = cast_bias_weight(
self, self,
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
return output return output
# Inference path (unchanged) # Inference path (unchanged)
if _use_quantized: if _use_quantized and quantize_input:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None) scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor)) weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
output = self.forward_comfy_cast_weights(
input,
compute_dtype,
want_requant=isinstance(input, QuantizedTensor),
weight_only_quant=weight_only_quant,
)
# Reshape output back to 3D if input was 3D # Reshape output back to 3D if input was 3D
if reshaped_3d: if reshaped_3d:
@ -1257,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None: if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else: else:
weight = weight.to(self.weight.dtype) weight = weight.to(self.weight.dtype)
if return_weight: if return_weight:

View File

@ -10,6 +10,7 @@ try:
QuantizedLayout, QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout, TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout, TensorCoreNVFP4Layout as _CKNvfp4Layout,
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
register_layout_op, register_layout_op,
register_layout_class, register_layout_class,
get_layout_class, get_layout_class,
@ -47,6 +48,9 @@ except ImportError as e:
class _CKNvfp4Layout: class _CKNvfp4Layout:
pass pass
class _CKTensorWiseINT8Layout:
pass
def register_layout_class(name, cls): def register_layout_class(name, cls):
pass pass
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
# Backward compatibility alias - default to E4M3 # Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
# ============================================================================== # ==============================================================================
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
if _CK_MXFP8_AVAILABLE: if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
"group_size": 32, "group_size": 32,
} }
QUANT_ALGOS["int8_tensorwise"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale"},
"comfy_tensor_layout": "TensorWiseINT8Layout",
"quantize_input": False,
}
# ============================================================================== # ==============================================================================
# Re-exports for backward compatibility # Re-exports for backward compatibility
@ -226,6 +239,7 @@ __all__ = [
"TensorCoreFP8E4M3Layout", "TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout", "TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout", "TensorCoreNVFP4Layout",
"TensorWiseINT8Layout",
"QUANT_ALGOS", "QUANT_ALGOS",
"register_layout_op", "register_layout_op",
] ]

View File

@ -891,6 +891,14 @@ class Tracks(ComfyTypeIO):
track_visibility: torch.Tensor track_visibility: torch.Tensor
Type = TrackDict Type = TrackDict
@comfytype(io_type="DICT")
class Dict(ComfyTypeIO):
Type = dict
@comfytype(io_type="ARRAY")
class Array(ComfyTypeIO):
Type = list
@comfytype(io_type="COMFY_MULTITYPED_V3") @comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType: class MultiType:
Type = Any Type = Any
@ -1279,6 +1287,19 @@ class Color(ComfyTypeIO):
def as_dict(self): def as_dict(self):
return super().as_dict() return super().as_dict()
@comfytype(io_type="COLORS")
class Colors(ComfyTypeIO):
Type = list[Color.Type]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[str]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = []
@comfytype(io_type="BOUNDING_BOX") @comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO): class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict): class BoundingBoxDict(TypedDict):
@ -1326,6 +1347,20 @@ class Curve(ComfyTypeIO):
return d return d
@comfytype(io_type="BOUNDING_BOXES")
class BoundingBoxes(ComfyTypeIO):
class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict):
metadata: dict
Type = list[BoundingBoxWithMetadata]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[dict]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = []
@comfytype(io_type="HISTOGRAM") @comfytype(io_type="HISTOGRAM")
class Histogram(ComfyTypeIO): class Histogram(ComfyTypeIO):
"""A histogram represented as a list of bin counts.""" """A histogram represented as a list of bin counts."""
@ -2376,6 +2411,8 @@ __all__ = [
"AnyType", "AnyType",
"MultiType", "MultiType",
"Tracks", "Tracks",
"Dict",
"Array",
"Color", "Color",
# Dynamic Types # Dynamic Types
"MatchType", "MatchType",
@ -2394,6 +2431,8 @@ __all__ = [
"PriceBadgeDepends", "PriceBadgeDepends",
"PriceBadge", "PriceBadge",
"BoundingBox", "BoundingBox",
"BoundingBoxes",
"Colors",
"Curve", "Curve",
"Histogram", "Histogram",
"Range", "Range",

View File

@ -163,15 +163,31 @@ class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.") asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input). # Dollars per 1K tokens, keyed by (model_id, has_video_input, resolution).
SEEDANCE2_PRICE_PER_1K_TOKENS = { SEEDANCE2_PRICE_PER_1K_TOKENS = {
("dreamina-seedance-2-0-260128", False): 0.007, ("dreamina-seedance-2-0-260128", False, "480p"): 0.007,
("dreamina-seedance-2-0-260128", True): 0.0043, ("dreamina-seedance-2-0-260128", True, "480p"): 0.0043,
("dreamina-seedance-2-0-fast-260128", False): 0.0056, ("dreamina-seedance-2-0-260128", False, "720p"): 0.007,
("dreamina-seedance-2-0-fast-260128", True): 0.0033, ("dreamina-seedance-2-0-260128", True, "720p"): 0.0043,
("dreamina-seedance-2-0-260128", False, "1080p"): 0.0077,
("dreamina-seedance-2-0-260128", True, "1080p"): 0.0047,
("dreamina-seedance-2-0-260128", False, "4k"): 0.004,
("dreamina-seedance-2-0-260128", True, "4k"): 0.0024,
("dreamina-seedance-2-0-fast-260128", False, "480p"): 0.0056,
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
("dreamina-seedance-2-0-mini", False, "480p"): 0.0035,
("dreamina-seedance-2-0-mini", True, "480p"): 0.0021,
("dreamina-seedance-2-0-mini", False, "720p"): 0.0035,
("dreamina-seedance-2-0-mini", True, "720p"): 0.0021,
} }
def seedance2_price_per_1k_tokens(model_id: str, has_video_input: bool, resolution: str) -> float | None:
return SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input, resolution))
RECOMMENDED_PRESETS = [ RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024), ("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152), ("864x1152 (3:4)", 864, 1152),
@ -266,6 +282,10 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"480p": {"min": 409_600, "max": 927_408}, "480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408}, "720p": {"min": 409_600, "max": 927_408},
}, },
"dreamina-seedance-2-0-mini": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
} }
# The time in this dictionary are given for 10 seconds duration. # The time in this dictionary are given for 10 seconds duration.

View File

@ -15,7 +15,6 @@ from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS_SEEDREAM_4_0, RECOMMENDED_PRESETS_SEEDREAM_4_0,
RECOMMENDED_PRESETS_SEEDREAM_4_5, RECOMMENDED_PRESETS_SEEDREAM_4_5,
RECOMMENDED_PRESETS_SEEDREAM_5_LITE, RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
SEEDANCE2_PRICE_PER_1K_TOKENS,
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS, SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
VIDEO_TASKS_EXECUTION_TIME, VIDEO_TASKS_EXECUTION_TIME,
GetAssetResponse, GetAssetResponse,
@ -40,6 +39,7 @@ from comfy_api_nodes.apis.bytedance import (
TaskVideoContentUrl, TaskVideoContentUrl,
Text2ImageTaskCreationRequest, Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest, Text2VideoTaskCreationRequest,
seedance2_price_per_1k_tokens,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
@ -89,6 +89,7 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont
SEEDANCE_MODELS = { SEEDANCE_MODELS = {
"Seedance 2.0": "dreamina-seedance-2-0-260128", "Seedance 2.0": "dreamina-seedance-2-0-260128",
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128", "Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
"Seedance 2.0 Mini": "dreamina-seedance-2-0-mini",
} }
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"} DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
@ -141,7 +142,7 @@ SEEDANCE2_RATIO_WH = {
"9:16": (9, 16), "9:16": (9, 16),
"21:9": (21, 9), "21:9": (21, 9),
} }
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080} SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080, "4k": 2160}
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]: def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
@ -377,9 +378,9 @@ async def _seedance_virtual_library_upload_video_asset(
return f"asset://{create_resp.asset_id}" return f"asset://{create_resp.asset_id}"
def _seedance2_price_extractor(model_id: str, has_video_input: bool): def _seedance2_price_extractor(model_id: str, has_video_input: bool, resolution: str):
"""Returns a price_extractor closure for Seedance 2.0 poll_op.""" """Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) rate = seedance2_price_per_1k_tokens(model_id, has_video_input, resolution)
if rate is None: if rate is None:
return None return None
@ -1621,10 +1622,12 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"model", "model",
options=[ options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])),
], ],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
), ),
IO.Int.Input( IO.Int.Input(
"seed", "seed",
@ -1660,11 +1663,16 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
$rate480 := 10044; $rate480 := 10044;
$rate720 := 21600; $rate720 := 21600;
$rate1080 := 48800; $rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model; $m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration"); $dur := $lookup(widgets, "model.duration");
$rate := $res = "1080p" ? $rate1080 : $pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 : $res = "720p" ? $rate720 :
$rate480; $rate480;
$cost := $dur * $rate * $pricePer1K / 1000; $cost := $dur * $rate * $pricePer1K / 1000;
@ -1703,7 +1711,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: r.status, status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
poll_interval=9, poll_interval=9,
) )
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -1724,14 +1732,19 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
options=[ options=[
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"Seedance 2.0", "Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"), _seedance2_text_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
), ),
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"Seedance 2.0 Fast", "Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"), _seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
), ),
IO.DynamicCombo.Option(
"Seedance 2.0 Mini",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
], ],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
), ),
IO.Image.Input( IO.Image.Input(
"first_frame", "first_frame",
@ -1791,11 +1804,16 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
$rate480 := 10044; $rate480 := 10044;
$rate720 := 21600; $rate720 := 21600;
$rate1080 := 48800; $rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model; $m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration"); $dur := $lookup(widgets, "model.duration");
$rate := $res = "1080p" ? $rate1080 : $pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 : $res = "720p" ? $rate720 :
$rate480; $rate480;
$cost := $dur * $rate * $pricePer1K / 1000; $cost := $dur * $rate * $pricePer1K / 1000;
@ -1913,7 +1931,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: r.status, status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
poll_interval=9, poll_interval=9,
) )
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -2010,14 +2028,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
options=[ options=[
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"Seedance 2.0", "Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"), _seedance2_reference_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
), ),
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"Seedance 2.0 Fast", "Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"), _seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
), ),
IO.DynamicCombo.Option(
"Seedance 2.0 Mini",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
], ],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
), ),
IO.Int.Input( IO.Int.Input(
"seed", "seed",
@ -2056,13 +2079,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
$rate480 := 10044; $rate480 := 10044;
$rate720 := 21600; $rate720 := 21600;
$rate1080 := 48800; $rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model; $m := widgets.model;
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0; $hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration"); $dur := $lookup(widgets, "model.duration");
$rate := $res = "1080p" ? $rate1080 : $noVideoPricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $res = "4k" ? 0.003432 :
$res = "1080p" ? 0.006721 :
$contains($m, "mini") ? 0.003003 :
$contains($m, "fast") ? 0.004719 : 0.006149;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 : $res = "720p" ? $rate720 :
$rate480; $rate480;
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000; $noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
@ -2258,7 +2289,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: r.status, status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input), price_extractor=_seedance2_price_extractor(
model_id, has_video_input=has_video_input, resolution=model["resolution"]
),
poll_interval=9, poll_interval=9,
) )
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))

View File

@ -0,0 +1,23 @@
def hex_to_rgb(value: str) -> tuple[int, int, int]:
h = value.lstrip("#")
if len(h) != 6:
return (255, 255, 255)
try:
return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
except ValueError:
return (255, 255, 255)
def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]:
r, g, b = rgb
lum = 0.299 * r + 0.587 * g + 0.114 * b
if lum >= 130:
return (r, g, b)
t = (130 - lum) / (255 - lum)
return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t))
def normalize_palette(colors) -> list[str]:
if isinstance(colors, dict):
colors = colors.values()
return [c.upper() for c in colors if isinstance(c, str) and c]

View File

@ -0,0 +1,253 @@
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color
_PREVIEW_LONG_EDGE = 1024
_PREVIEW_DIM = 0.25
def pixels_to_fractions(box: dict, width: int, height: int) -> dict:
w = width or 1
h = height or 1
return {
"x": box.get("x", 0) / w,
"y": box.get("y", 0) / h,
"w": box.get("width", 0) / w,
"h": box.get("height", 0) / h,
}
def fractions_to_pixels(box: dict, width: int, height: int) -> dict:
x, y = box.get("x", 0.0), box.get("y", 0.0)
w, h = box.get("w", 0.0), box.get("h", 0.0)
if w < 0:
x, w = x + w, -w
if h < 0:
y, h = y + h, -h
return {
"x": round(x * width),
"y": round(y * height),
"width": round(w * width),
"height": round(h * height),
}
def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list:
pixels = [
fractions_to_pixels(box, width, height)
for box in boxes
if isinstance(box, dict)
]
return [pixels] if pixels else []
def _font(size: int):
try:
return ImageFont.load_default(size)
except Exception:
return ImageFont.load_default()
def _wrap(draw, text: str, font, max_w: float) -> list[str]:
lines = []
for para in text.split("\n"):
line = ""
for word in para.split():
test = word if not line else line + " " + word
if line and draw.textlength(test, font=font) > max_w:
lines.append(line)
line = word
else:
line = test
lines.append(line)
return lines
def _bg_from_image(image) -> Image.Image | None:
if image is None:
return None
try:
arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
return Image.fromarray(arr)
except Exception:
return None
def render_preview(regions, width, height, bg=None):
if bg is not None:
iw, ih = bg.size
long_edge = max(iw, ih) or 1
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale))
base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS)
base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM)
img = base.convert("RGBA")
else:
long_edge = max(width, height) or 1
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
rw, rh = max(1, round(width * scale)), max(1, round(height * scale))
grey = round(_PREVIEW_DIM * 128)
img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255))
overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0))
draw = ImageDraw.Draw(overlay)
fs = max(10, round(rh / 64))
font = _font(fs)
tag_font = _font(max(9, fs - 2))
line_h = fs + 2
for i, region in enumerate(regions):
if not isinstance(region, dict):
continue
palette = [c for c in (region.get("palette") or []) if c]
r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140)
x1 = max(0, min(rw, round(region.get("x", 0) * rw)))
y1 = max(0, min(rh, round(region.get("y", 0) * rh)))
x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw)))
y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh)))
if x2 < x1:
x1, x2 = x2, x1
if y2 < y1:
y1, y2 = y2, y1
draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2)
swatches = palette[:5]
if swatches and (x2 - x1) > 2:
sh = max(5, fs // 2)
seg = (x2 - x1) / len(swatches)
for p, hexc in enumerate(swatches):
sx = x1 + round(p * seg)
draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc))
etype = "text" if region.get("type") == "text" else "obj"
tag = str(i + 1).zfill(2)
tw = draw.textlength(tag, font=tag_font)
draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255))
tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255)
draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font)
body = region.get("desc", "") or ""
if etype == "text" and region.get("text"):
body = '"%s"%s' % (region["text"], "" + body if body else "")
if body and (x2 - x1) > 8:
ty = y1 + fs + 5
for line in _wrap(draw, body, font, x2 - x1 - 8):
if ty > y2:
break
draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font)
ty += line_h
composed = Image.alpha_composite(img, overlay).convert("RGB")
arr = np.asarray(composed, dtype=np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0)
def boxes_to_regions(boxes, width: int, height: int) -> list:
regions: list = []
if not isinstance(boxes, list):
return regions
for box in boxes:
if not isinstance(box, dict):
continue
meta = box.get("metadata")
meta = meta if isinstance(meta, dict) else {}
regions.append({
**pixels_to_fractions(box, width, height),
"type": meta.get("type", "obj"),
"text": meta.get("text", ""),
"desc": meta.get("desc", ""),
"palette": meta.get("palette", []),
})
return regions
def _norm_bbox(region: dict) -> list[int]:
def grid(value: float) -> int:
return max(0, min(1000, round(value * 1000)))
x, y = region.get("x", 0.0), region.get("y", 0.0)
w, h = region.get("w", 0.0), region.get("h", 0.0)
ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w)
if ymin > ymax:
ymin, ymax = ymax, ymin
if xmin > xmax:
xmin, xmax = xmax, xmin
return [ymin, xmin, ymax, xmax]
def build_elements(regions: list) -> list:
elements = []
for region in regions:
if not isinstance(region, dict):
continue
etype = "text" if region.get("type") == "text" else "obj"
element = {"type": etype}
element["bbox"] = _norm_bbox(region)
if etype == "text":
element["text"] = region.get("text", "")
element["desc"] = region.get("desc", "")
palette = normalize_palette(region.get("palette", []))
if palette:
element["color_palette"] = palette[:5]
elements.append(element)
return elements
class CreateBoundingBoxes(io.ComfyNode):
@classmethod
def define_schema(cls):
editor_state = io.BoundingBoxes.Input(
"editor_state",
socketless=False,
tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.",
)
return io.Schema(
node_id="CreateBoundingBoxes",
display_name="Create Bounding Boxes",
category="utilities",
description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.",
inputs=[
io.Image.Input(
"background",
optional=True,
tooltip="Optional image used as background in the canvas and preview.",
),
io.Int.Input("width", default=1024, min=64, max=16384, step=16,
tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
io.Int.Input("height", default=1024, min=64, max=16384, step=16,
tooltip="Height of the canvas and the pixel grid for the bounding boxes."),
editor_state,
],
outputs=[
io.Image.Output(display_name="preview"),
io.BoundingBox.Output(display_name="bboxes"),
io.Array.Output(display_name="elements"),
],
is_experimental=True,
)
@classmethod
def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput:
regions = boxes_to_regions(editor_state, width, height)
preview = render_preview(regions, width, height, _bg_from_image(background))
return io.NodeOutput(
preview,
fractions_to_bbox_frame(regions, width, height),
build_elements(regions),
ui={"dims": [width, height]},
)
class BoundingBoxesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [CreateBoundingBoxes]
async def comfy_entrypoint() -> BoundingBoxesExtension:
return BoundingBoxesExtension()

View File

@ -1,5 +1,6 @@
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import hex_to_rgb
class ColorToRGBInt(io.ComfyNode): class ColorToRGBInt(io.ComfyNode):
@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode):
# expect format #RRGGBB # expect format #RRGGBB
if len(color) != 7 or color[0] != "#": if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB") raise ValueError("Color must be in format #RRGGBB")
r = int(color[1:3], 16) try:
g = int(color[3:5], 16) int(color[1:], 16)
b = int(color[5:7], 16) except ValueError:
raise ValueError("Color must be in format #RRGGBB") from None
r, g, b = hex_to_rgb(color)
rgb_int = r * 256 * 256 + g * 256 + b rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color) return io.NodeOutput(rgb_int, color)

View File

@ -1,85 +1,68 @@
import os import os
import sys import sys
import re import re
import ctypes
import logging import logging
import ctypes.util
import importlib.util
from typing import TypedDict from typing import TypedDict
import numpy as np import numpy as np
import torch import torch
import nodes import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _check_opengl_availability(): def _preload_angle():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work.""" egl_path = comfy_angle.get_egl_path()
logger.debug("_check_opengl_availability: starting") gles_path = comfy_angle.get_glesv2_path()
missing = []
# Check Python packages (using find_spec to avoid importing) if sys.platform == "win32":
logger.debug("_check_opengl_availability: checking for glfw package") angle_dir = comfy_angle.get_lib_dir()
if importlib.util.find_spec("glfw") is None: os.add_dll_directory(angle_dir)
missing.append("glfw") os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
logger.debug("_check_opengl_availability: checking for OpenGL package") mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
if importlib.util.find_spec("OpenGL") is None: ctypes.CDLL(str(egl_path), mode=mode)
missing.append("PyOpenGL") ctypes.CDLL(str(gles_path), mode=mode)
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
# Run early check at import time # Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
logger.debug("nodes_glsl: running _check_opengl_availability at import time") # plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_check_opengl_availability() _preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
def _import_opengl(): import OpenGL
"""Import OpenGL module. Called after context is created.""" OpenGL.USE_ACCELERATE = False
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict): class SizeModeInput(TypedDict):
size_mode: str size_mode: str
width: int width: int
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1) # (-1,-1)---(3,-1)
# #
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1) # v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord; out vec2 v_texCoord;
void main() { void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3)); vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@ -126,14 +109,99 @@ void main() {
""" """
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.""" def _egl_attribs(*values):
# Remove any existing #version directive """Build an EGL_NONE-terminated EGLint attribute array."""
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE) vals = list(values) + [EGL.EGL_NONE]
# Remove precision qualifiers (not needed in desktop GLSL) return (ctypes.c_int32 * len(vals))(*vals)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source # EGL platform extension constants
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
_eglGetPlatformDisplayEXT = None
def _get_egl_platform_display_ext(platform, native_display, attribs):
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
global _eglGetPlatformDisplayEXT
if _eglGetPlatformDisplayEXT is None:
from OpenGL import platform as _plat
egl_lib = _plat.PLATFORM.EGL
_get_proc = egl_lib.eglGetProcAddress
_get_proc.restype = ctypes.c_void_p
_get_proc.argtypes = [ctypes.c_char_p]
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
if not ptr:
return None
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
_eglGetPlatformDisplayEXT = func_type(ptr)
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
if not raw:
return None
return ctypes.cast(raw, EGL.EGLDisplay)
def _get_egl_display():
"""Get an EGL display, trying the default first then ANGLE's Vulkan
platform for headless environments without a display server."""
failures = []
# Try the default display first (works when X11/Wayland is available)
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if display:
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
return display, major.value, minor.value
except Exception as e:
failures.append(f"default: {e}")
logger.info("Default EGL display unavailable, trying headless fallbacks")
# Headless fallback strategies, tried in order:
headless_strategies = [
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
]
for name, platform, native_display, attribs in headless_strategies:
display = _get_egl_platform_display_ext(platform, native_display, attribs)
if not display:
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
continue
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
logger.info(f"Using EGL {name} platform (headless)")
return display, major.value, minor.value
failures.append(f"{name}: eglInitialize returned false")
except Exception as e:
failures.append(f"{name}: {e}")
continue
details = "\n".join(f" - {f}" for f in failures)
raise RuntimeError(
"Failed to initialize EGL display.\n"
"No display server and no headless EGL platform available.\n"
f"Tried:\n{details}\n"
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _detect_output_count(source: str) -> int: def _detect_output_count(source: str) -> int:
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
return 1 return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext: class GLContext:
"""Manages OpenGL context and resources for shader execution. """Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
Tries backends in order: GLFW (desktop) EGL (headless GPU) OSMesa (software).
"""
_instance = None _instance = None
_initialized = False _initialized = False
@ -327,131 +240,105 @@ class GLContext:
def __init__(self): def __init__(self):
if GLContext._initialized: if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time import time
start = time.perf_counter() start = time.perf_counter()
self._backend = None self._display = None
self._window = None self._surface = None
self._egl_display = None self._context = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._vao = None self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try: try:
self._window, glfw = _init_glfw() self._display, self._egl_major, self._egl_minor = _get_egl_display()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
if self._backend is None: if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
logger.debug("GLContext.__init__: trying EGL backend") raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if self._backend is None: config = EGL.EGLConfig()
logger.debug("GLContext.__init__: trying OSMesa backend") n_configs = ctypes.c_int32(0)
try: if not EGL.eglChooseConfig(
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa() self._display,
self._backend = "osmesa" _egl_attribs(
logger.debug("GLContext.__init__: OSMesa backend succeeded") EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
except Exception as e: EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}") EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
errors.append(("OSMesa", e)) EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None: self._surface = EGL.eglCreatePbufferSurface(
if sys.platform == "win32": self._display, config,
platform_help = ( _egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
) )
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
# Now import OpenGL.GL (after context is current) self._context = EGL.eglCreateContext(
logger.debug("GLContext.__init__: importing OpenGL.GL") self._display, config, EGL.EGL_NO_CONTEXT,
_import_opengl() _egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Create VAO (required for core profile, but OSMesa may use compat profile) if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
logger.debug("GLContext.__init__: creating VAO") raise RuntimeError("eglMakeCurrent() failed")
try:
vao = gl.glGenVertexArrays(1) self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao) gl.glBindVertexArray(self._vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully") except Exception:
except Exception as e: self._cleanup()
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}") raise
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
elapsed = (time.perf_counter() - start) * 1000 elapsed = (time.perf_counter() - start) * 1000
# Log device info renderer = _gl_str(gl.GL_RENDERER)
renderer = gl.glGetString(gl.GL_RENDERER) vendor = _gl_str(gl.GL_VENDOR)
vendor = gl.glGetString(gl.GL_VENDOR) version = _gl_str(gl.GL_VERSION)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
GLContext._initialized = True GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}") logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
def make_current(self): def make_current(self):
if self._backend == "glfw": if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
glfw.make_context_current(self._window) err = EGL.eglGetError()
elif self._backend == "egl": raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if self._vao is not None: if self._vao is not None:
gl.glBindVertexArray(self._vao) gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int: def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID.""" """Compile a shader and return its ID."""
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source) gl.glShaderSource(shader, source)
gl.glCompileShader(shader) gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader).decode() error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteShader(shader) gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}") raise RuntimeError(f"Shader compilation failed:\n{error}")
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader) gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader) gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE: if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program).decode() error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteProgram(program) gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}") raise RuntimeError(f"Program linking failed:\n{error}")
@ -530,9 +421,6 @@ def _render_shader_batch(
ctx = GLContext() ctx = GLContext()
ctx.make_current() ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses # Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code) num_outputs = _detect_output_count(fragment_code)
@ -558,9 +446,9 @@ def _render_shader_batch(
try: try:
# Compile shaders (once for all batches) # Compile shaders (once for all batches)
try: try:
program = _create_program(VERTEX_SHADER, fragment_source) program = _create_program(VERTEX_SHADER, fragment_code)
except RuntimeError: except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}") logger.error(f"Fragment shader:\n{fragment_code}")
raise raise
gl.glUseProgram(program) gl.glUseProgram(program)
@ -723,13 +611,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch # Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
batch_outputs = [] batch_outputs = []
for tex in output_textures: for i in range(num_outputs):
gl.glBindTexture(gl.GL_TEXTURE_2D, tex) gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT) buf = np.empty((height, width, 4), dtype=np.float32)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4) gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(img[::-1, :, :].copy()) batch_outputs.append(buf[::-1, :, :].copy())
# Pad with black images for unused outputs # Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32) black_img = np.zeros((height, width, 4), dtype=np.float32)
@ -750,18 +638,18 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0) gl.glUseProgram(0)
for tex in input_textures: if input_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(input_textures), input_textures)
for tex in curve_textures: if curve_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(curve_textures), curve_textures)
for tex in output_textures: if output_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(output_textures), output_textures)
for tex in ping_pong_textures: if ping_pong_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None: if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo]) gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos: if ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo]) gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
if program is not None: if program is not None:
gl.glDeleteProgram(program) gl.glDeleteProgram(program)

View File

@ -0,0 +1,77 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import normalize_palette
class BuildJsonPromptIdeogram(io.ComfyNode):
@classmethod
def define_schema(cls):
color_palette = io.Colors.Input(
"color_palette",
socketless=False,
tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.",
)
return io.Schema(
node_id="BuildJsonPromptIdeogram",
display_name="Build JSON Prompt (Ideogram)",
category="text",
description="Build a JSON prompt for the Ideogram 4 model.",
inputs=[
io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."),
io.String.Input("high_level_description", multiline=True, default="",
tooltip="Optional description of the image in one or two sentences. Strongly recommended."),
io.String.Input("background", multiline=True, default="",
tooltip="Mandatory description of the image background or environment."),
io.DynamicCombo.Input("style", options=[
io.DynamicCombo.Option("none", []),
io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]),
io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]),
]),
io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."),
io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."),
io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."),
color_palette,
],
outputs=[io.Dict.Output(display_name="prompt")],
is_experimental=True,
)
@classmethod
def execute(cls, element, style, high_level_description="", background="",
aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput:
elements = element if isinstance(element, list) else []
kind = style.get("style", "none") if isinstance(style, dict) else "none"
photo = style.get("photo", "") if isinstance(style, dict) else ""
art_style = style.get("art_style", "") if isinstance(style, dict) else ""
palette = normalize_palette(color_palette or [])
caption: dict = {}
if high_level_description.strip():
caption["high_level_description"] = high_level_description
if kind != "none":
style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting}
if kind == "photo":
style_desc["photo"] = photo
style_desc["medium"] = medium
else:
style_desc["medium"] = medium
style_desc["art_style"] = art_style
if palette:
style_desc["color_palette"] = palette
caption["style_description"] = style_desc
caption["compositional_deconstruction"] = {
"background": background,
"elements": elements,
}
return io.NodeOutput(caption)
class JsonPromptExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [BuildJsonPromptIdeogram]
async def comfy_entrypoint() -> JsonPromptExtension:
return JsonPromptExtension()

View File

@ -337,6 +337,36 @@ class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict} return {"required": arg_dict}
class ModelMergeKrea2(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "model/merging/model specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["first."] = argument
arg_dict["tmlp."] = argument
arg_dict["txtmlp."] = argument
arg_dict["tproj."] = argument
for i in range(2):
arg_dict["txtfusion.layerwise_blocks.{}.".format(i)] = argument
arg_dict["txtfusion.projector."] = argument
for i in range(2):
arg_dict["txtfusion.refiner_blocks.{}.".format(i)] = argument
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["last."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1, "ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -353,4 +383,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B, "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B, "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage, "ModelMergeQwenImage": ModelMergeQwenImage,
"ModelMergeKrea2": ModelMergeKrea2,
} }

View File

@ -0,0 +1,33 @@
import sys
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SeedNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedNode",
display_name="Seed",
search_aliases=["seed", "random"],
category="utilities",
inputs=[
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
],
outputs=[io.Int.Output(display_name="seed")],
)
@classmethod
def execute(cls, seed: int) -> io.NodeOutput:
return io.NodeOutput(seed)
class SeedExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [SeedNode]
async def comfy_entrypoint() -> SeedExtension:
return SeedExtension()

View File

@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode):
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
return io.NodeOutput("") return io.NodeOutput("")
def _dump_json(value, indent):
return json.dumps(value, ensure_ascii=False, indent=indent or None)
class ConvertDictionaryToString(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertDictionaryToString",
display_name="Convert Dictionary to String",
category="text",
search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"],
inputs=[
io.Dict.Input("dictionary"),
io.Int.Input("indent", default=2, min=0, max=8,
tooltip="Spaces per indent level. 0 produces compact single-line string."),
],
outputs=[
io.String.Output(),
],
)
@classmethod
def execute(cls, dictionary, indent=2):
return io.NodeOutput(_dump_json(dictionary, indent))
class ConvertArrayToString(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertArrayToString",
display_name="Convert Array to String",
category="text",
search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"],
inputs=[
io.Array.Input("array"),
io.Int.Input("indent", default=2, min=0, max=8,
tooltip="Spaces per indent level. 0 produces compact single-line string."),
],
outputs=[
io.String.Output(),
],
)
@classmethod
def execute(cls, array, indent=2):
return io.NodeOutput(_dump_json(array, indent))
class StringExtension(ComfyExtension): class StringExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -457,6 +508,8 @@ class StringExtension(ComfyExtension):
RegexExtract, RegexExtract,
RegexReplace, RegexReplace,
JsonExtractString, JsonExtractString,
ConvertDictionaryToString,
ConvertArrayToString,
] ]
async def comfy_entrypoint() -> StringExtension: async def comfy_entrypoint() -> StringExtension:

View File

@ -2387,6 +2387,8 @@ async def init_builtin_extra_nodes():
"nodes_images.py", "nodes_images.py",
"nodes_video_model.py", "nodes_video_model.py",
"nodes_ideogram4.py", "nodes_ideogram4.py",
"nodes_bounding_boxes.py",
"nodes_json_prompt.py",
"nodes_train.py", "nodes_train.py",
"nodes_dataset.py", "nodes_dataset.py",
"nodes_sag.py", "nodes_sag.py",
@ -2486,6 +2488,7 @@ async def init_builtin_extra_nodes():
"nodes_gaussian_splat.py", "nodes_gaussian_splat.py",
"nodes_triposplat.py", "nodes_triposplat.py",
"nodes_depth_anything_3.py", "nodes_depth_anything_3.py",
"nodes_seed.py",
] ]
import_failed = [] import_failed = []

View File

@ -1692,6 +1692,12 @@ paths:
schema: schema:
$ref: '#/components/schemas/ErrorResponse' $ref: '#/components/schemas/ErrorResponse'
description: Unsupported media type description: Unsupported media type
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Validation error (e.g., disallowed model_type tag)
"500": "500":
content: content:
application/json: application/json:
@ -2137,6 +2143,12 @@ paths:
schema: schema:
$ref: '#/components/schemas/ErrorResponse' $ref: '#/components/schemas/ErrorResponse'
description: Source asset with given hash not found description: Source asset with given hash not found
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Validation error (e.g., disallowed model_type tag)
"500": "500":
content: content:
application/json: application/json:
@ -2992,7 +3004,7 @@ paths:
format: uuid format: uuid
type: string type: string
- description: | - description: |
When present, each output item in the response receives a `short_url` field containing an owner-gated durable link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime: use `ephemeral_tool_chain` for short-lived machine-to-machine handoffs (~15 minutes); use `default` for durable human-revisitable links (30 days). Links are minted only for the authenticated request owner and are not resolvable by other users. When present, each output item in the response receives a `short_url` field containing a short link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime and auth model: use `ephemeral_tool_chain` for short-lived (≤5 minute) machine-to-machine handoffs — these are public bearer links where the link ID itself is the credential, so anyone holding the link can resolve it (intended for pasting into an agent/MCP tool chain); use `default` for durable (30 day) human-revisitable links, which are owner-gated and resolvable only by the authenticated owner. Links are always minted under the authenticated request owner's identity; the auth model is selected by the server and is never settable by the caller.
in: query in: query
name: short_link name: short_link
schema: schema:

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.45.19 comfyui-frontend-package==1.45.19
comfyui-workflow-templates==0.10.2 comfyui-workflow-templates==0.10.7
comfyui-embedded-docs==0.5.5 comfyui-embedded-docs==0.5.5
torch torch
torchsde torchsde
@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0 SQLAlchemy>=2.0.0
filelock filelock
av>=16.0.0 av>=16.0.0
comfy-kitchen==0.2.10 comfy-kitchen==0.2.14
comfy-aimdo==0.4.10 comfy-aimdo==0.4.10
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel spandrel
pydantic~=2.0 pydantic~=2.0
pydantic-settings~=2.0 pydantic-settings~=2.0
PyOpenGL PyOpenGL>=3.1.8
glfw comfy-angle

View File

@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
def test_int8_convrot_metadata_loads_into_params(self):
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
torch.manual_seed(123)
layer_quant_config = {
"layer": {
"format": "int8_tensorwise",
"convrot": True,
"convrot_groupsize": 256,
}
}
weight = torch.randn(16, 256, dtype=torch.bfloat16)
bias = torch.randn(16, dtype=torch.bfloat16)
q_weight = QuantizedTensor.from_float(
weight,
"TensorWiseINT8Layout",
per_channel=True,
convrot=True,
convrot_groupsize=256,
)
state_dict = {
"layer.weight": q_weight._qdata,
"layer.bias": bias,
"layer.weight_scale": q_weight._params.scale,
}
state_dict, _ = comfy.utils.convert_old_quants(
state_dict,
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
)
model = torch.nn.Module()
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
model.load_state_dict(state_dict, strict=False)
self.assertIsInstance(model.layer.weight, QuantizedTensor)
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
self.assertTrue(model.layer.weight._params.convrot)
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
loaded_out = model.layer(input_tensor)
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
self.assertTrue(torch.equal(loaded_out, ref_out))
fp16_input = input_tensor.to(torch.float16)
loaded_fp16_out = model.layer(fp16_input)
ref_fp16_out = torch.nn.functional.linear(
fp16_input,
q_weight.to(dtype=torch.float16),
bias.to(dtype=torch.float16),
)
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
saved = model.state_dict()
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
self.assertTrue(saved_conf["convrot"])
self.assertEqual(saved_conf["convrot_groupsize"], 256)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()