mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00
convert nodes_gits.py to V3 schema (#9949)
This commit is contained in:
parent
638097829d
commit
7eb7160db4
@ -1,6 +1,8 @@
|
||||
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
def loglinear_interp(t_steps, num_steps):
|
||||
"""
|
||||
@ -333,25 +335,28 @@ NOISE_LEVELS = {
|
||||
],
|
||||
}
|
||||
|
||||
class GITSScheduler:
|
||||
class GITSScheduler(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
|
||||
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GITSScheduler",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
|
||||
io.Int.Input("steps", default=10, min=2, max=1000),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Sigmas.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, coeff, steps, denoise):
|
||||
@classmethod
|
||||
def execute(cls, coeff, steps, denoise):
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
if denoise <= 0.0:
|
||||
return (torch.FloatTensor([]),)
|
||||
return io.NodeOutput(torch.FloatTensor([]))
|
||||
total_steps = round(steps * denoise)
|
||||
|
||||
if steps <= 20:
|
||||
@ -362,8 +367,16 @@ class GITSScheduler:
|
||||
|
||||
sigmas = sigmas[-(total_steps + 1):]
|
||||
sigmas[-1] = 0
|
||||
return (torch.FloatTensor(sigmas), )
|
||||
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GITSScheduler": GITSScheduler,
|
||||
}
|
||||
|
||||
class GITSSchedulerExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
GITSScheduler,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> GITSSchedulerExtension:
|
||||
return GITSSchedulerExtension()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user