This post is an explanation of the k-means++ method and how to interpret its results using silhouettes. I’ve wanted to write about this topic since the inception of this website because of its simplicity and elegance. Devised in 2007 by Arthur and Vassilvitskii, k-means++ addresses the problem of choosing initial centroids for the naive k-means algorithm, whose exact solution is NP-hard, and yields a clustering that is provably close, in expectation, to the optimal solution.
Data clustering is an important problem with many applications in science and engineering. The objective is to partition data into groups where each datum in a group is similar to other data in the same group, but different from data in other groups, by some measure of similarity. The typical motivation for using clustering techniques is to uncover latent structure in data, which can be used for knowledge discovery, dimensionality reduction, feature engineering, and a host of other applications.
k-means and its derivatives are among the most commonly used clustering techniques in unsupervised learning. The technique is popular because it’s simple to implement and scales reasonably well with input size. Given a dataset $X$ of observations $\{x_1,x_2,\dots,x_n\}$ and clusters $C_1,C_2,\dots,C_K$, k-means attempts to assign a cluster label to each observation $x_i \in X$ by minimizing an objective function, typically the squared Euclidean distance. The goal is to minimize the intra-cluster variance: for each cluster, sum the pairwise squared distances between its members and normalize by the number of observations in that cluster. This quantity is equivalent to the sum of squared distances from each datum to the cluster centroid, so the position of each centroid is given by the multivariate mean, $\mu_k$, of the feature vectors in the cluster:
$$ \min_{c_1,\dots,c_k} \Bigg \{ \sum_{k=1}^{K} \sum_{i \in C_k} || x_i - \mu_k ||^2 \Bigg \} $$
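To make the objective concrete, here is a minimal sketch in plain JavaScript (dependency-free and separate from the full implementation later in the post; the function names are my own) that computes the within-cluster sum of squares for a candidate partition:

```javascript
// Squared Euclidean distance between two feature vectors
function sqDist(a, b) {
  return a.reduce(function(sum, ai, i) {
    return sum + Math.pow(ai - b[i], 2);
  }, 0);
}

// Multivariate mean (centroid) of a cluster's points
function centroid(points) {
  var mu = new Array(points[0].length).fill(0);
  points.forEach(function(p) {
    p.forEach(function(v, i) { mu[i] += v / points.length; });
  });
  return mu;
}

// Within-cluster sum of squares: the quantity k-means minimizes.
// `clusters` is an array of clusters, each an array of feature vectors.
function objective(clusters) {
  return clusters.reduce(function(total, pts) {
    var mu = centroid(pts);
    return total + pts.reduce(function(s, p) { return s + sqDist(p, mu); }, 0);
  }, 0);
}
```

For a single cluster containing the points (0, 0) and (2, 0), the centroid is (1, 0) and the objective is 2.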
Finding the optimal positions for the cluster centroids is NP-hard. There are roughly $K^n$ ways to assign $n$ samples to $K$ clusters, so even in $\mathbb{R}^2$ exhaustive search is intractable. However, in the 1950s Stuart Lloyd proposed a simple approach that converges to a locally optimal solution. Lloyd’s algorithm uses an iterative descent technique of repeated assignment and update steps to converge on an approximate solution to the problem.^{1} The algorithm works as follows:
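In outline, each iteration alternates an assignment step (label every point with its nearest centroid) and an update step (move every centroid to the mean of its assigned points), stopping once the labels no longer change. A toy sketch on 1-D data, with names of my own choosing:

```javascript
// Lloyd's algorithm on 1-D data: alternate assignment and update
// steps until the labels stop changing (a toy illustration).
function lloyd(points, centroids) {
  var labels = [];
  while (true) {
    // assignment step: label each point with its nearest centroid
    var newLabels = points.map(function(p) {
      var best = 0;
      centroids.forEach(function(c, i) {
        if (Math.abs(p - c) < Math.abs(p - centroids[best])) best = i;
      });
      return best;
    });
    if (String(newLabels) === String(labels)) break; // converged
    labels = newLabels;
    // update step: move each centroid to the mean of its points
    centroids = centroids.map(function(c, i) {
      var mine = points.filter(function(p, j) { return labels[j] === i; });
      return mine.length
        ? mine.reduce(function(a, b) { return a + b; }, 0) / mine.length
        : c;
    });
  }
  return {centroids: centroids, labels: labels};
}
```

For example, `lloyd([1, 2, 9, 10], [1, 10])` converges to centroids 1.5 and 9.5 with labels `[0, 0, 1, 1]`.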
Convergence is usually reached within only a few iterations, once the centroid assignments stop changing. The least-squares estimate is guaranteed to converge, but the objective function is non-convex, so the algorithm is not guaranteed to find the global minimum of the sum of squares.
Despite its speed, Lloyd’s algorithm can yield arbitrarily bad clusters compared to the optimal solution. The randomness in initial centroid positioning can cause the minimization to terminate in particularly bad local optima, yielding inaccurate cluster labels. Consider a motivating example where the algorithm selects two points next to one another as two of the centroid positions:^{2}
How could the centroids be initialized to produce better results? One option would be to position the centroids at the extrema such that each centroid is as far away as possible from the other centroids. The problem with this approach is that it can also produce bad clusters because it’s sensitive to outliers in the data:
k-means++ initializes centroids by striking a compromise between the random approach of Lloyd’s algorithm shown in Figure 2 and the extrema approach shown in Figure 3. Arthur and Vassilvitskii proposed selecting the first centroid at random from the data, then initializing each remaining centroid by sampling probabilistically, with probability proportional to each point’s squared distance to its nearest existing centroid. The intuition is that the initialization should favor centroids that are distal to existing centroids, while also favoring regions of high data density. This simple improvement to the initialization of the naive method guarantees a solution whose expected cost is within a factor of $8(\ln{k} + 2)$ of the optimal clustering. Incredible!
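A minimal sketch of the seeding step on 1-D data (my own simplification, not the post's implementation below): after picking the first centroid at random, each subsequent centroid is drawn by inverse-transform sampling from the distribution $D(x)^2 / \sum D(x)^2$:

```javascript
// k-means++ seeding on 1-D data: sample each new centroid with
// probability proportional to the squared distance D(x)^2 to the
// nearest centroid chosen so far. Already-chosen points have
// D(x)^2 = 0, so they are never picked again.
// `rand` is injectable for testing; defaults to Math.random.
function kmppSeed(points, k, rand) {
  rand = rand || Math.random;
  var centers = [points[Math.floor(rand() * points.length)]];
  while (centers.length < k) {
    // D(x)^2 for every point
    var d2 = points.map(function(p) {
      return Math.min.apply(null, centers.map(function(c) {
        return (p - c) * (p - c);
      }));
    });
    var total = d2.reduce(function(a, b) { return a + b; }, 0);
    // inverse-transform sample from the D^2 distribution
    var u = rand() * total, cum = 0, next = points[points.length - 1];
    for (var i = 0; i < points.length; i++) {
      cum += d2[i];
      if (u < cum) { next = points[i]; break; }
    }
    centers.push(next);
  }
  return centers;
}
```

Notice how an isolated far-away point gets a large share of the probability mass, but a dense region of moderately distant points can collectively outweigh it, which is exactly the compromise described above.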
k-means++ is frequently applied to high dimensional data, and the results of the algorithm can be difficult to interpret. A technique called silhouetting was developed in the 1980s to help analyze the results of such clustering tasks. The idea of silhouetting is to use a measure of cluster quality, called a silhouette coefficient. The coefficient is computed by calculating two terms—$a_i$ and $b_i$, which are measures of cluster cohesion and separation, respectively. Cohesion is measured as the mean similarity between the $i$th datum and all other elements belonging to the same cluster. Separation is measured as the mean dissimilarity between the $i$th datum and all data belonging to the most proximal neighboring cluster. The silhouette coefficient, $s_i$ for the $i$th datum is then defined as follows:
$$ s_i = \begin{cases} 1 - a_i/b_i, & \text{if } a_i < b_i \\ b_i/a_i - 1, & \text{if } a_i > b_i \\ 0, & \text{otherwise} \end{cases} $$
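Equivalently, $s_i = (b_i - a_i) / \max(a_i, b_i)$. A small sketch computing the coefficient for a single observation on 1-D data (a hypothetical helper of my own, using plain absolute distance as the dissimilarity):

```javascript
// Silhouette coefficient for one point:
//   a = mean distance to the other points in its own cluster,
//   b = mean distance to the points of the nearest other cluster.
// Assumes the point's own cluster has at least two members.
function silhouette(point, ownCluster, otherClusters) {
  var mean = function(arr) {
    return arr.reduce(function(s, v) { return s + v; }, 0) / arr.length;
  };
  var a = mean(ownCluster
    .filter(function(p) { return p !== point; })
    .map(function(p) { return Math.abs(point - p); }));
  var b = Math.min.apply(null, otherClusters.map(function(cl) {
    return mean(cl.map(function(p) { return Math.abs(point - p); }));
  }));
  if (a < b) return 1 - a / b;
  if (a > b) return b / a - 1;
  return 0;
}
```

For the point 1 in cluster {1, 2} with a neighboring cluster {9, 10}, $a = 1$ and $b = 8.5$, giving $s \approx 0.88$: tightly cohesive and well separated.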
To construct a silhouette plot, the coefficients for each datum in a cluster are computed and sorted in decreasing order, and the size of each silhouette shows how many data elements belong to the cluster. When the silhouette coefficients are plotted, they produce a characteristic sawtooth pattern depicting the distribution of coefficients in the cluster, as shown in Figure 1B.
The silhouette coefficient lies in the interval $[-1, 1]$. Coefficients approaching 1 indicate data that is well partitioned within its cluster, while coefficients near 0 indicate data that may fall between multiple clusters. Observations with negative coefficients have likely been assigned to the wrong cluster. Here’s an example where I’ve manually positioned the centroids to show how the silhouette coefficients highlight poor clusters:
Below is my implementation of the k-means++ method in JavaScript. I wrote a distance function library called furlong to simplify some of the distance computations. The library is available on GitHub.
(function() {
  // distance func to use
  var distFunc = furlong.distance('euclidean')

  // return the object and value of the min key in an array of objects
  var argmin = function(arr, key) {
    return arr.reduce(function(acc, el) {
      var thisVal = el[key]
      if (thisVal < acc.val) {
        acc.val = thisVal
        acc.datum = el
      }
      return acc
    }, {val: Infinity, datum: undefined})
  }

  // helpers for inverse transform sampling
  var takeWhile = function(arr, pred) { return arr.filter(pred) }

  var inverseTransform = function(arr, key) {
    var u = Math.random()
      , fn = function(d) { return d[key] < u };
    return takeWhile(arr, fn).slice(-1)[0]
  }

  // return a random int from the domain [min, max]
  var randInt = function(min, max) {
    return Math.floor(Math.random() * (max - min + 1)) + min;
  }

  // return a random sample of n indices from an array
  var randSample = function(arr, n) {
    var max = arr.length - 1
    return d3.range(n).map(function() { return randInt(0, max) })
  }

  // generate a cluster of gaussian data
  var genCluster = function(K, cov) {
    var n = 10000 / K
    var genFuncs = cov.map(function() {
      var loc = randInt(10, 50)
        , scale = 2;
      return d3.random.normal(loc, scale)
    })
    return d3.range(n).map(function() {
      return genFuncs.reduce(function(acc, f, i) {
        acc[cov[i]] = f()
        return acc;
      }, {})
    })
  }

  // generate K clusters of normally distributed data
  var genClusters = function(K, cov) {
    var data = []
    d3.range(K).forEach(function() { data.push(genCluster(K, cov)) })
    return [].concat.apply([], data).map(function(d, i) {
      var obj = cov.reduce(function(acc, el) {
        acc[el] = d[el];
        return acc;
      }, {})
      obj.idx = i;
      return obj;
    })
  }

  // choose centroids with probability: D(x)^2 / \sum D(x)^2
  var choiceCentroid = function(arr, cov) {
    var distSum = d3.sum(arr, function(d) { return d.dist })
    var cuprobs = arr.reduce(function(acc, el) {
      var prob = el.dist / distSum
        , cusum = acc.cusum + prob
        , datum = {idx: el.idx, dist: el.dist, prob: prob, cusum: cusum};
      acc.data.push(keyTraverse(el, datum, cov))
      acc.cusum = cusum
      return acc;
    }, {cusum: 0, data: []})
    return inverseTransform(cuprobs.data, 'cusum')
  }

  // accessor function to retrieve values from objects
  var pluck = function(obj, keys) {
    return keys.map(function(d) { return obj[d] })
  }

  // copy keys from a source object to a target object
  var keyTraverse = function(sobj, dobj, keys) {
    keys.forEach(function(key) { dobj[key] = sobj[key]; })
    return dobj;
  }

  // seed the initial centroids for the kmpp algorithm
  var kmppSeed = function(data, K, cov) {
    var cData = d3.range(K).reduce(function(acc, clabel) {
      var centroid
      if (clabel === 0) {
        // choose 1 centroid at random from the data
        centroid = data[randSample(data, 1)[0]]
      } else {
        // \forall data, excluding centroids, compute
        // D(x)^2 between x_i and the proximal centroid
        var dists = data
          .filter(function(datum) {
            return !(datum.idx in acc.seenCentroids)
          })
          .map(function(datum) {
            var datumDists = acc.centroids.map(function(centroid) {
              var cVec = pluck(centroid, cov)
                , dVec = pluck(datum, cov)
                , dist = Math.pow(distFunc(cVec, dVec), 2)
                , props = {idx: datum.idx, dist: dist};
              return keyTraverse(datum, props, cov)
            })
            return argmin(datumDists, 'dist').datum
          })
          .sort(function(a, b) { return a.dist - b.dist })
        centroid = choiceCentroid(dists, cov)
      }
      centroid.label = clabel
      acc.centroids.push(centroid)
      acc.seenCentroids[centroid.idx] = 0
      return acc
    }, {centroids: [], seenCentroids: {}})
    return cData.centroids
  }

  // \forall datum, \forall centroids, compute the distance
  // to each centroid and assign the closest one
  var assignCentroidLabels = function(data, centroids, cov) {
    var converged = true
    var labeledPts = data.map(function(datum) {
      var prevLabel = datum.label
      var dists = centroids.map(function(centroid) {
        var cVec = pluck(centroid, cov)
          , dVec = pluck(datum, cov)
          , dist = Math.pow(distFunc(cVec, dVec), 2);
        return {dist: dist, label: centroid.label}
      });
      var minCentroidLabel = argmin(dists, 'dist').datum.label
      if (prevLabel !== minCentroidLabel) { converged = false }
      var props = {idx: datum.idx, label: minCentroidLabel}
      return keyTraverse(datum, props, cov)
    })
    return {converged: converged, pts: labeledPts}
  }

  // after each datum has been assigned a label,
  // update the position of each centroid
  var repositionCentroids = function(data, centroids, cov) {
    var sortByCentroid = data.reduce(function(acc, el) {
      if (!(el.label in acc)) { acc[el.label] = [] }
      acc[el.label].push(el)
      return acc;
    }, {})
    return Object.keys(sortByCentroid).map(function(label) {
      var labelData = sortByCentroid[label]
      var datum = cov.reduce(function(acc, el) {
        acc[el] = d3.mean(labelData, function(d) { return d[el] })
        return acc;
      }, {})
      datum.label = +label
      return datum
    })
  }

  // run the iterative assignment/update phase of the algo
  var go = function(pts, centroids, cov) {
    // cache the intermediate positions of
    // the centroids during convergence
    var allCentroids = [centroids]

    // define the initial data structures
    var data = assignCentroidLabels(pts, centroids, cov)
      , converged = data.converged;
    pts = data.pts

    // when label assignment converges,
    // return the final centroid positions
    while (converged === false) {
      centroids = repositionCentroids(pts, centroids, cov)
      allCentroids.push(centroids)
      data = assignCentroidLabels(pts, centroids, cov)
      pts = data.pts
      converged = data.converged
    }
    return {centroids: centroids, pts: pts, allCentroids: allCentroids}
  }

  // top-level kmeans++ algo
  var kmpp = function(clusterData, K, cov) {
    var centroids = kmppSeed(clusterData, K, cov)
    return go(clusterData, centroids, cov)
  }

  // cluster with 5-dimensional input data
  var covariates = ['a', 'b', 'c', 'd', 'e']
    , K = covariates.length
    , data = genClusters(K, covariates);

  console.log(kmpp(data, K, covariates))
}())
Lloyd’s method is a special case of expectation-maximization. ↩
Interactively move the centroids in Figure 1 to see how the cluster assignments are affected by centroid position. ↩