Compare commits

..

1 Commits

Author SHA1 Message Date
Alexis Rolland
2f9d725f4b
Merge 70ebe28cd6 into 025e6792ee 2026-05-03 08:55:29 -07:00
5 changed files with 40 additions and 44 deletions

View File

@ -91,7 +91,6 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
class LatentPreviewMethod(enum.Enum): class LatentPreviewMethod(enum.Enum):
NoPreviews = "none" NoPreviews = "none"

View File

@ -1,8 +1,6 @@
import torch import torch
import logging import logging
from comfy.cli_args import args
try: try:
import comfy_kitchen as ck import comfy_kitchen as ck
from comfy_kitchen.tensor import ( from comfy_kitchen.tensor import (
@ -23,15 +21,7 @@ try:
ck.registry.disable("cuda") ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
if args.enable_triton_backend: ck.registry.disable("triton")
try:
import triton
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
except ImportError as e:
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
ck.registry.disable("triton")
else:
ck.registry.disable("triton")
for k, v in ck.list_backends().items(): for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}") logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e: except ImportError as e:

View File

@ -666,13 +666,12 @@ class ColorTransfer(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="ColorTransfer", node_id="ColorTransfer",
display_name="Color Transfer",
category="image/postprocessing", category="image/postprocessing",
description="Match the colors of one image to another using various algorithms.", description="Match the colors of one image to another using various algorithms.",
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
inputs=[ inputs=[
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."), io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
io.DynamicCombo.Input("source_stats", io.DynamicCombo.Input("source_stats",
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",

View File

@ -49,7 +49,7 @@ class Int(io.ComfyNode):
display_name="Int", display_name="Int",
category="utils/primitive", category="utils/primitive",
inputs=[ inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed), io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
], ],
outputs=[io.Int.Output()], outputs=[io.Int.Output()],
) )

View File

@ -1754,49 +1754,57 @@ class LoadImage:
return True return True
class LoadImageMask:
class LoadImageMask(LoadImage):
ESSENTIALS_CATEGORY = "Image Tools" ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"] _color_channels = ["alpha", "red", "green", "blue"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
types = super().INPUT_TYPES() input_dir = folder_paths.get_input_directory()
return { files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
"required": { return {"required":
**types["required"], {"image": (sorted(files), {"image_upload": True}),
"channel": (s._color_channels, ) "channel": (s._color_channels, ), }
} }
}
CATEGORY = "mask" CATEGORY = "mask"
RETURN_TYPES = ("MASK",) RETURN_TYPES = ("MASK",)
FUNCTION = "load_image_mask" FUNCTION = "load_image"
def load_image(self, image, channel):
def load_image_mask(self, image, channel): image_path = folder_paths.get_annotated_filepath(image)
image_tensor, mask_tensor = super().load_image(image) i = node_helpers.pillow(Image.open, image_path)
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA")
mask = None
c = channel[0].upper() c = channel[0].upper()
if c in i.getbands():
if c == 'A': mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
return (mask_tensor,) mask = torch.from_numpy(mask)
if c == 'A':
channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0) mask = 1. - mask
if channel_idx < image_tensor.shape[-1]:
return (image_tensor[..., channel_idx].clone(),)
else: else:
empty_mask = torch.zeros( mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
image_tensor.shape[:-1], return (mask.unsqueeze(0),)
dtype=image_tensor.dtype,
device=image_tensor.device
)
return (empty_mask,)
@classmethod @classmethod
def IS_CHANGED(s, image, channel): def IS_CHANGED(s, image, channel):
return super().IS_CHANGED(image) image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageOutput(LoadImage): class LoadImageOutput(LoadImage):