convert nodes_model_downscale.py to V3 schema (#10199)

This commit is contained in:
Alexander Piskun 2025-10-09 09:10:23 +03:00 committed by GitHub
parent 51fb505ffa
commit 2ba8d7cce8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,24 +1,33 @@
from typing_extensions import override
import comfy.utils import comfy.utils
from comfy_api.latest import ComfyExtension, io
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] class PatchModelAddDownscale(io.ComfyNode):
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), node_id="PatchModelAddDownscale",
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), category="model_patches/unet",
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), inputs=[
"downscale_after_skip": ("BOOLEAN", {"default": True}), io.Model.Input("model"),
"downscale_method": (s.upscale_methods,), io.Int.Input("block_number", default=3, min=1, max=32, step=1),
"upscale_method": (s.upscale_methods,), io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
}} io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
RETURN_TYPES = ("MODEL",) io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
FUNCTION = "patch" io.Boolean.Input("downscale_after_skip", default=True),
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
model_sampling = model.get_model_object("model_sampling") model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent) sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent) sigma_end = model_sampling.percent_to_sigma(end_percent)
@ -41,13 +50,21 @@ class PatchModelAddDownscale:
else: else:
m.set_model_input_block_patch(input_block_patch) m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch) m.set_model_output_block_patch(output_block_patch)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling # Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", "PatchModelAddDownscale": "",
} }
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PatchModelAddDownscale,
]
async def comfy_entrypoint() -> ModelDownscaleExtension:
return ModelDownscaleExtension()