mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 09:47:35 +08:00
CURVE node
This commit is contained in:
parent
b6ddc590ed
commit
4c11d5f75f
@ -1049,6 +1049,48 @@ class ManualSigmas(io.ComfyNode):
|
|||||||
sigmas = torch.FloatTensor(sigmas)
|
sigmas = torch.FloatTensor(sigmas)
|
||||||
return io.NodeOutput(sigmas)
|
return io.NodeOutput(sigmas)
|
||||||
|
|
||||||
|
class CurveToSigmas(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CurveToSigmas",
|
||||||
|
display_name="Curve to Sigmas",
|
||||||
|
category="sampling/custom_sampling/schedulers",
|
||||||
|
inputs=[
|
||||||
|
io.Curve.Input("curve", default=[[0.0, 1.0], [1.0, 0.0]]),
|
||||||
|
io.Model.Input("model", optional=True),
|
||||||
|
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||||
|
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, optional=True),
|
||||||
|
],
|
||||||
|
outputs=[io.Sigmas.Output()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, curve, steps, sigma_max=14.614642, model=None) -> io.NodeOutput:
|
||||||
|
points = sorted(curve, key=lambda p: p[0])
|
||||||
|
model_sampling = model.get_model_object("model_sampling") if model is not None else None
|
||||||
|
|
||||||
|
sigmas = []
|
||||||
|
for i in range(steps + 1):
|
||||||
|
t = i / steps
|
||||||
|
y = points[0][1] if t < points[0][0] else points[-1][1]
|
||||||
|
for j in range(len(points) - 1):
|
||||||
|
if points[j][0] <= t <= points[j + 1][0]:
|
||||||
|
x0, y0 = points[j]
|
||||||
|
x1, y1 = points[j + 1]
|
||||||
|
y = y0 if x1 == x0 else y0 + (y1 - y0) * (t - x0) / (x1 - x0)
|
||||||
|
break
|
||||||
|
if model_sampling is not None:
|
||||||
|
sigmas.append(float(model_sampling.percent_to_sigma(1.0 - y)))
|
||||||
|
else:
|
||||||
|
sigmas.append(y * sigma_max)
|
||||||
|
|
||||||
|
sigmas[-1] = 0.0
|
||||||
|
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||||
|
|
||||||
|
get_sigmas = execute
|
||||||
|
|
||||||
|
|
||||||
class CustomSamplersExtension(ComfyExtension):
|
class CustomSamplersExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -1088,6 +1130,7 @@ class CustomSamplersExtension(ComfyExtension):
|
|||||||
AddNoise,
|
AddNoise,
|
||||||
SamplerCustomAdvanced,
|
SamplerCustomAdvanced,
|
||||||
ManualSigmas,
|
ManualSigmas,
|
||||||
|
CurveToSigmas,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
20
nodes.py
20
nodes.py
@ -2034,6 +2034,24 @@ class ImagePadForOutpaint:
|
|||||||
return (new_image, mask.unsqueeze(0))
|
return (new_image, mask.unsqueeze(0))
|
||||||
|
|
||||||
|
|
||||||
|
class CurveEditor:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"curve": ("CURVE", {"default": [[0, 0], [1, 1]]}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CURVE",)
|
||||||
|
RETURN_NAMES = ("curve",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "utils"
|
||||||
|
|
||||||
|
def execute(self, curve):
|
||||||
|
return (curve,)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"KSampler": KSampler,
|
"KSampler": KSampler,
|
||||||
"CheckpointLoaderSimple": CheckpointLoaderSimple,
|
"CheckpointLoaderSimple": CheckpointLoaderSimple,
|
||||||
@ -2102,6 +2120,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ConditioningZeroOut": ConditioningZeroOut,
|
"ConditioningZeroOut": ConditioningZeroOut,
|
||||||
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
|
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
|
||||||
"LoraLoaderModelOnly": LoraLoaderModelOnly,
|
"LoraLoaderModelOnly": LoraLoaderModelOnly,
|
||||||
|
"CurveEditor": CurveEditor,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -2170,6 +2189,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
# _for_testing
|
# _for_testing
|
||||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||||
|
"CurveEditor": "Curve Editor",
|
||||||
}
|
}
|
||||||
|
|
||||||
EXTENSION_WEB_DIRS = {}
|
EXTENSION_WEB_DIRS = {}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user