mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-29 15:50:22 +08:00
Remove nodes_color
This commit is contained in:
parent
ba600832be
commit
6ade1d77e6
@ -1,185 +0,0 @@
|
|||||||
import colour
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from scipy.interpolate import interpn
|
|
||||||
from scipy.interpolate.interpnd import LinearNDInterpolator
|
|
||||||
from scipy.ndimage import gaussian_filter
|
|
||||||
from scipy.optimize._lsap import linear_sum_assignment
|
|
||||||
from sklearn.cluster import KMeans
|
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
from comfy.nodes.package_typing import CustomNode
|
|
||||||
|
|
||||||
|
|
||||||
class ColorPaletteExtractor(CustomNode):
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"image": ("IMAGE",),
|
|
||||||
"num_colors": ("INT", {"default": 8, "min": 2, "max": 64, "step": 1}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "PALETTE")
|
|
||||||
FUNCTION = "extract_palette"
|
|
||||||
CATEGORY = "image/color"
|
|
||||||
|
|
||||||
def extract_palette(self, image, num_colors):
|
|
||||||
if len(image.shape) == 3:
|
|
||||||
image = image.unsqueeze(0)
|
|
||||||
|
|
||||||
image_np = image.squeeze().cpu().numpy()
|
|
||||||
|
|
||||||
pixels = image_np.reshape(-1, 3)
|
|
||||||
|
|
||||||
kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
|
|
||||||
kmeans.fit(pixels)
|
|
||||||
|
|
||||||
colors = kmeans.cluster_centers_
|
|
||||||
|
|
||||||
_, counts = np.unique(kmeans.labels_, return_counts=True)
|
|
||||||
sorted_indices = np.argsort(counts)[::-1]
|
|
||||||
colors = colors[sorted_indices]
|
|
||||||
|
|
||||||
palette_height = 512
|
|
||||||
palette_width = 512
|
|
||||||
palette_image = np.zeros((palette_height, palette_width * num_colors, 3), dtype=np.uint8)
|
|
||||||
for i, color in enumerate(colors):
|
|
||||||
palette_image[:, i * palette_width:(i + 1) * palette_width] = color
|
|
||||||
|
|
||||||
palette_tensor = torch.from_numpy(palette_image).float() / 255.0
|
|
||||||
palette_tensor = palette_tensor.unsqueeze(0).to(comfy.model_management.get_torch_device())
|
|
||||||
|
|
||||||
color_array = (colors * 255).astype(np.uint8)
|
|
||||||
|
|
||||||
return palette_tensor, color_array
|
|
||||||
|
|
||||||
|
|
||||||
class ImageBasedColorRemap(CustomNode):
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"image": ("IMAGE",),
|
|
||||||
"reference_image": ("IMAGE",),
|
|
||||||
"palette_size": ("INT", {"default": 8, "min": 2, "max": 64, "step": 1}),
|
|
||||||
"lut_size": ("INT", {"default": 33, "min": 8, "max": 64, "step": 1}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
|
||||||
FUNCTION = "remap_colors"
|
|
||||||
CATEGORY = "image/color"
|
|
||||||
|
|
||||||
def remap_colors(self, image, reference_image, palette_size, lut_size):
|
|
||||||
# Ensure images are in the correct format (B, H, W, C)
|
|
||||||
image = image.squeeze(0) if image.ndim == 4 else image
|
|
||||||
reference_image = reference_image.squeeze(0) if reference_image.ndim == 4 else reference_image
|
|
||||||
|
|
||||||
# Convert torch tensors to numpy arrays
|
|
||||||
image_np = image.cpu().numpy()
|
|
||||||
reference_np = reference_image.cpu().numpy()
|
|
||||||
|
|
||||||
# Extract palettes using k-means clustering
|
|
||||||
image_palette = self.extract_palette(image_np, palette_size)
|
|
||||||
reference_palette = self.extract_palette(reference_np, palette_size)
|
|
||||||
|
|
||||||
# Convert palettes to LAB color space
|
|
||||||
image_palette_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(image_palette))
|
|
||||||
reference_palette_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(reference_palette))
|
|
||||||
|
|
||||||
# Align palettes
|
|
||||||
aligned_image_palette_lab, aligned_reference_palette_lab = self.align_palettes(image_palette_lab, reference_palette_lab)
|
|
||||||
|
|
||||||
# Create 3D LUT in LAB space
|
|
||||||
lut = self.create_color_remap_3dlut(aligned_image_palette_lab, aligned_reference_palette_lab, size=lut_size)
|
|
||||||
|
|
||||||
# Apply 3D LUT
|
|
||||||
image_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(image_np))
|
|
||||||
remapped_lab = self.apply_lut(image_lab, lut, lut_size)
|
|
||||||
remapped_rgb = colour.XYZ_to_sRGB(colour.Lab_to_XYZ(remapped_lab))
|
|
||||||
|
|
||||||
# Clip values to ensure they're in the valid range
|
|
||||||
remapped_rgb = np.clip(remapped_rgb, 0, 1)
|
|
||||||
|
|
||||||
# Convert back to torch tensor
|
|
||||||
remapped_tensor = torch.from_numpy(remapped_rgb).float()
|
|
||||||
remapped_tensor = remapped_tensor.unsqueeze(0) # Add batch dimension
|
|
||||||
|
|
||||||
return (remapped_tensor,)
|
|
||||||
|
|
||||||
def extract_palette(self, image, palette_size):
|
|
||||||
pixels = image.reshape(-1, 3)
|
|
||||||
kmeans = KMeans(n_clusters=palette_size, random_state=42, n_init=10)
|
|
||||||
kmeans.fit(pixels)
|
|
||||||
return kmeans.cluster_centers_
|
|
||||||
|
|
||||||
def align_palettes(self, palette1, palette2):
|
|
||||||
distances = np.linalg.norm(palette1[:, np.newaxis] - palette2, axis=2)
|
|
||||||
row_ind, col_ind = linear_sum_assignment(distances)
|
|
||||||
return palette1, palette2[col_ind]
|
|
||||||
|
|
||||||
def create_color_remap_3dlut(self, original_colors, target_colors, size=33):
|
|
||||||
# Create a regular grid in LAB space
|
|
||||||
L = np.linspace(0, 100, size)
|
|
||||||
a = np.linspace(-128, 127, size)
|
|
||||||
b = np.linspace(-128, 127, size)
|
|
||||||
grid = np.meshgrid(L, a, b, indexing='ij')
|
|
||||||
|
|
||||||
# Reshape the grid for KNN
|
|
||||||
grid_points = np.vstack([g.ravel() for g in grid]).T
|
|
||||||
|
|
||||||
# Use KNN to find the nearest neighbor for each grid point
|
|
||||||
from sklearn.neighbors import NearestNeighbors
|
|
||||||
nn = NearestNeighbors(n_neighbors=1, algorithm='auto').fit(original_colors)
|
|
||||||
_, indices = nn.kneighbors(grid_points)
|
|
||||||
|
|
||||||
# Create the LUT using the target colors
|
|
||||||
lut = target_colors[indices.ravel()]
|
|
||||||
lut = lut.reshape(size, size, size, 3)
|
|
||||||
|
|
||||||
# Apply Gaussian smoothing to create a more continuous mapping
|
|
||||||
for i in range(3):
|
|
||||||
lut[:,:,:,i] = gaussian_filter(lut[:,:,:,i], sigma=1)
|
|
||||||
|
|
||||||
return lut
|
|
||||||
|
|
||||||
def apply_lut(self, image_lab, lut, lut_size):
|
|
||||||
points = (
|
|
||||||
np.linspace(0, 100, lut_size), # L
|
|
||||||
np.linspace(-128, 127, lut_size), # a
|
|
||||||
np.linspace(-128, 127, lut_size) # b
|
|
||||||
)
|
|
||||||
|
|
||||||
xi = np.stack([
|
|
||||||
image_lab[..., 0],
|
|
||||||
image_lab[..., 1],
|
|
||||||
image_lab[..., 2]
|
|
||||||
], axis=-1)
|
|
||||||
|
|
||||||
remapped_lab = np.zeros_like(image_lab)
|
|
||||||
for i in range(3):
|
|
||||||
remapped_lab[..., i] = interpn(
|
|
||||||
points,
|
|
||||||
lut[..., i],
|
|
||||||
xi,
|
|
||||||
method='linear',
|
|
||||||
bounds_error=False,
|
|
||||||
fill_value=None
|
|
||||||
)
|
|
||||||
|
|
||||||
return remapped_lab
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"ColorPaletteExtractor": ColorPaletteExtractor,
|
|
||||||
"ImageBasedColorRemap": ImageBasedColorRemap,
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"ColorPaletteExtractor": "Extract Color Palette",
|
|
||||||
"ImageBasedColorRemap": "Image-Based Color Remap",
|
|
||||||
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue
Block a user