convert nodes_torch_compile.py to V3 schema (#10173)

This commit is contained in:
Alexander Piskun 2025-10-03 21:43:54 +03:00 committed by GitHub
parent c2c5a7d5f8
commit 3e68bc342c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,23 +1,39 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper from comfy_api.torch_helpers import set_torch_compile_wrapper
class TorchCompileModel: class TorchCompileModel(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": { "model": ("MODEL",), return io.Schema(
"backend": (["inductor", "cudagraphs"],), node_id="TorchCompileModel",
}} category="_for_testing",
RETURN_TYPES = ("MODEL",) inputs=[
FUNCTION = "patch" io.Model.Input("model"),
io.Combo.Input(
"backend",
options=["inductor", "cudagraphs"],
),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
CATEGORY = "_for_testing" @classmethod
EXPERIMENTAL = True def execute(cls, model, backend) -> io.NodeOutput:
def patch(self, model, backend):
m = model.clone() m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend) set_torch_compile_wrapper(model=m, backend=backend)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel, class TorchCompileExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TorchCompileModel,
]
async def comfy_entrypoint() -> TorchCompileExtension:
return TorchCompileExtension()