mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
convert nodes_gits.py to V3 schema (#9949)
This commit is contained in:
parent
638097829d
commit
7eb7160db4
@ -1,6 +1,8 @@
|
|||||||
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
|
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
def loglinear_interp(t_steps, num_steps):
|
def loglinear_interp(t_steps, num_steps):
|
||||||
"""
|
"""
|
||||||
@ -333,25 +335,28 @@ NOISE_LEVELS = {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
class GITSScheduler:
|
class GITSScheduler(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required":
|
return io.Schema(
|
||||||
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
|
node_id="GITSScheduler",
|
||||||
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
|
category="sampling/custom_sampling/schedulers",
|
||||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
inputs=[
|
||||||
}
|
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
|
||||||
}
|
io.Int.Input("steps", default=10, min=2, max=1000),
|
||||||
RETURN_TYPES = ("SIGMAS",)
|
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Sigmas.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "get_sigmas"
|
@classmethod
|
||||||
|
def execute(cls, coeff, steps, denoise):
|
||||||
def get_sigmas(self, coeff, steps, denoise):
|
|
||||||
total_steps = steps
|
total_steps = steps
|
||||||
if denoise < 1.0:
|
if denoise < 1.0:
|
||||||
if denoise <= 0.0:
|
if denoise <= 0.0:
|
||||||
return (torch.FloatTensor([]),)
|
return io.NodeOutput(torch.FloatTensor([]))
|
||||||
total_steps = round(steps * denoise)
|
total_steps = round(steps * denoise)
|
||||||
|
|
||||||
if steps <= 20:
|
if steps <= 20:
|
||||||
@ -362,8 +367,16 @@ class GITSScheduler:
|
|||||||
|
|
||||||
sigmas = sigmas[-(total_steps + 1):]
|
sigmas = sigmas[-(total_steps + 1):]
|
||||||
sigmas[-1] = 0
|
sigmas[-1] = 0
|
||||||
return (torch.FloatTensor(sigmas), )
|
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"GITSScheduler": GITSScheduler,
|
class GITSSchedulerExtension(ComfyExtension):
|
||||||
}
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
GITSScheduler,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> GITSSchedulerExtension:
|
||||||
|
return GITSSchedulerExtension()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user