convert nodes_gits.py to V3 schema (#9949)

This commit is contained in:
Alexander Piskun 2025-10-01 22:16:59 +03:00 committed by GitHub
parent 638097829d
commit 7eb7160db4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,8 @@
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
import numpy as np import numpy as np
import torch import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps): def loglinear_interp(t_steps, num_steps):
""" """
@ -333,25 +335,28 @@ NOISE_LEVELS = {
], ],
} }
class GITSScheduler: class GITSScheduler(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}), node_id="GITSScheduler",
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}), category="sampling/custom_sampling/schedulers",
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), inputs=[
} io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
} io.Int.Input("steps", default=10, min=2, max=1000),
RETURN_TYPES = ("SIGMAS",) io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
CATEGORY = "sampling/custom_sampling/schedulers" ],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas" @classmethod
def execute(cls, coeff, steps, denoise):
def get_sigmas(self, coeff, steps, denoise):
total_steps = steps total_steps = steps
if denoise < 1.0: if denoise < 1.0:
if denoise <= 0.0: if denoise <= 0.0:
return (torch.FloatTensor([]),) return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise) total_steps = round(steps * denoise)
if steps <= 20: if steps <= 20:
@ -362,8 +367,16 @@ class GITSScheduler:
sigmas = sigmas[-(total_steps + 1):] sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0 sigmas[-1] = 0
return (torch.FloatTensor(sigmas), ) return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"GITSScheduler": GITSScheduler, class GITSSchedulerExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
GITSScheduler,
]
async def comfy_entrypoint() -> GITSSchedulerExtension:
return GITSSchedulerExtension()