Density-Based Clustering For the Win!
Comparing three clustering algorithms on high-dimensional data
🚩 In this post: how to use clustering algorithms to detect cyberattacks on smart devices; the negative selection algorithm; k-means clustering; and DBSCAN.
1:Why does clustering detect hacks so well?
I’ll call the main idea ‘pattern-based threat detection.’ We have many examples of normal network requests and cyberattacks for smart devices. We find ‘patterns’ of normal network requests and compare these to incoming requests. 🔍 If the patterns and requests match, we predict the requests are normal. Otherwise, we block the requests as cyberattacks.
Mathematically, we measure 21 statistics about each incoming network request. Ex: how many bytes are transferred, how many packets are received, how many packets are sent, etc. Call these statistics x1, x2, ..., x21. We package them into a vector X. Then, clustering algorithms identify the ‘pattern’ or ‘average statistics’ of normal datapoints: a vector Xavg. We can compare this to the statistics of an incoming network request, Xreq, with any distance function, e.g. the Euclidean distance:

d(Xreq, Xavg) = ||Xreq − Xavg||
If this distance is larger than some threshold T, the incoming request is too far from the ‘patterns’ of normal requests, and we predict it is malicious. 👺 In practice, we save multiple ‘common patterns’, not just one vector Xavg. Then, an incoming request has ‘multiple chances’ to fall within one of the patterns’ thresholds.
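As a concrete sketch of this rule (toy 3-feature vectors rather than the real 21 statistics; all names and numbers here are made up for illustration):

```python
import numpy as np

def is_malicious(x_req, patterns, threshold):
    """Flag a request as malicious if it is farther than `threshold`
    from every saved pattern vector (Euclidean distance here)."""
    dists = np.linalg.norm(patterns - x_req, axis=1)
    return bool(dists.min() > threshold)

# Two saved 'patterns' of normal traffic (toy 3-feature vectors)
patterns = np.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])

is_malicious(np.array([0.9, 0.1, 0.0]), patterns, threshold=0.5)  # close to a pattern -> False
is_malicious(np.array([5.0, 5.0, 5.0]), patterns, threshold=0.5)  # far from both -> True
```

Note how cheap this is at runtime: one distance computation per saved pattern, and nothing else.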
Note: this technique requires very little processing power and memory, which is lucky for smart devices with small microprocessors. This is because only a distance is computed and only a few pattern vectors are stored. Compare this to the millions of parameters and matrix multiplications in neural networks! 😵
There is one catch, though… how do we find these magical ‘patterns’ of normal network requests?
2:Negative Selection Algorithm
I started by deploying the negative selection algorithm. TL;DR: It generates random vectors as guesses for the common ‘patterns’. If the guesses are too far from any normal network request vector, then those random guesses are regenerated.
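A minimal sketch of that idea (my own toy reconstruction, not the exact code; the uniform sampling range, `max_dist`, and `max_tries` are all assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

def generate_patterns(normal_data, n_patterns, max_dist, max_tries=10_000):
    """Sample random candidate 'pattern' vectors inside the data's bounding
    box; keep a candidate only if some normal datapoint lies within
    `max_dist` of it (otherwise the guess is effectively regenerated)."""
    lo, hi = normal_data.min(axis=0), normal_data.max(axis=0)
    patterns = []
    for _ in range(max_tries):
        candidate = rng.uniform(lo, hi)
        if np.linalg.norm(normal_data - candidate, axis=1).min() <= max_dist:
            patterns.append(candidate)
            if len(patterns) == n_patterns:
                break
    return np.array(patterns)

normal = rng.normal(0.0, 1.0, size=(200, 3))  # stand-in for normal requests
patterns = generate_patterns(normal, n_patterns=5, max_dist=1.0)
```

On well-behaved data like this Gaussian stand-in, the rejection loop finishes quickly; on long-tailed real features, most candidates get rejected, which is exactly the problem described next.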
The issue with the negative selection algorithm was that my search space had bimodal peaks and outliers. Some of my features, like the duration of a network request, could range from 0 to 100,000,000 milliseconds! Random numbers picked in that range are almost guaranteed to be far from the normal request data. 😕
Because of this messy data, the negative selection algorithm essentially made random guesses far from the normal network requests. This led to spiky performance changes.
The typical response here would be to just normalise my data so that it centres around a mean of 0 with a standard deviation of 1. Sadly, normalisation would make the algorithm less accurate when the data distribution shifts.
In simple terms, if we train an algorithm on ‘transformed’ data, we also have to transform incoming network requests in real use.
On top of the increased latency, the algorithm also gains a new dependency: it has to use the mean and standard deviation calculated from the data seen in training.
However, this mean and standard deviation shift over time in the real world. Ex: users make new kinds of requests after a company rolls out new software. This often makes the algorithm more unpredictable and less accurate. 😢
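In code, that dependency looks something like this (toy numbers; the point is that `mu` and `sigma` from training must ship with the model):

```python
import numpy as np

# Hypothetical training data: (bytes transferred, packets sent) per request
train = np.array([[100.0, 2.0], [200.0, 4.0], [300.0, 6.0]])
mu, sigma = train.mean(axis=0), train.std(axis=0)  # must be saved for deployment

def normalise(x_req):
    """Every incoming request must be transformed with the *training*
    mean/std before any distance check: extra latency, plus statistics
    that silently go stale when real traffic drifts."""
    return (x_req - mu) / sigma

normalise(np.array([200.0, 4.0]))  # -> [0. 0.]
```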
To see if I could avoid normalisation, I made the algorithm less sensitive. Earlier, each guess was regenerated if any normal network request was too far from it. I changed this to only regenerate guesses if:
Two guesses were too close to each other (doesn’t usefully ‘cover’ all datapoints)
A datapoint had no guess within some distance (adds guesses where needed).
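The two checks above can be sketched like this (hypothetical names; `min_sep` and `cover_dist` stand in for the thresholds I tuned):

```python
import numpy as np

def check_guesses(guesses, data, min_sep, cover_dist):
    """Rule 1: a guess within `min_sep` of an earlier guess is redundant.
    Rule 2: datapoints with no guess within `cover_dist` are uncovered,
    so a new guess is needed near them."""
    redundant = sorted({
        j
        for i in range(len(guesses))
        for j in range(i + 1, len(guesses))
        if np.linalg.norm(guesses[i] - guesses[j]) < min_sep
    })
    dists = np.linalg.norm(data[:, None, :] - guesses[None, :, :], axis=2)
    uncovered = np.where(dists.min(axis=1) > cover_dist)[0]
    return redundant, uncovered

guesses = np.array([[0.0, 0.0], [0.05, 0.0], [5.0, 5.0]])
data = np.array([[0.1, 0.0], [10.0, 10.0]])
redundant, uncovered = check_guesses(guesses, data, min_sep=0.5, cover_dist=1.0)
# redundant == [1] (guess 1 duplicates guess 0)
# uncovered == [1] (datapoint [10, 10] has no nearby guess)
```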
This showed no changes in performance though… 😢😢
At this point, my modified negative selection algorithm reminded me of a k-means clustering algorithm with overly-random guesses. So I decided to switch tracks and transition to using a standard k-means clustering algorithm.
3:K-means Clustering (Round 1)
TL;DR of how this algorithm works:
Make random guesses about ‘patterns’ in network requests.
Find which datapoints are close to those patterns.
Update each guess to be the average of the datapoints it’s close to.
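The three steps above, sketched in NumPy on toy 2-D data (for stability, this toy seeds the guesses with random datapoints rather than the fully random vectors my first attempt used):

```python
import numpy as np

def kmeans(data, k, n_iters=20, seed=0):
    rng = np.random.default_rng(seed)
    # Step 1: initial guesses (random datapoints, to keep this toy stable)
    centroids = data[rng.choice(len(data), size=k, replace=False)].copy()
    for _ in range(n_iters):
        # Step 2: assign each datapoint to its closest guess
        dists = np.linalg.norm(data[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 3: move each guess to the average of its assigned datapoints
        for c in range(k):
            if np.any(labels == c):
                centroids[c] = data[labels == c].mean(axis=0)
    return centroids, labels

data = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids, labels = kmeans(data, k=2)
# centroids end up near [0, 0.5] and [10, 10.5], in either order
```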
This sounds simple. When I inspected the details, though, almost all datapoints were assigned to just two of the random guesses. The initial random guesses were so poor that almost all of them were far from any typical network request.
Again, this issue came back to the outliers and extreme data distributions described above. So I was forced to clean the data before proceeding any further. 😤
4:Normalisation+Removing Outliers with DBSCAN
The simple step was to normalise all the data. Though this doesn’t fix bimodal distributions, it reduces the ‘weight’ of outliers.
Since so many statistics had a ‘long tail’ distribution, I decided to remove the few outliers at the ‘tail-end’ of the distributions. To do this, I used the DBSCAN algorithm (Density-Based Spatial Clustering of Applications with Noise). 🤓
The name is long, but the concept is simple:
Find datapoints in dense areas (lots of other datapoints around them)
Group datapoints in dense areas together, including datapoints near but not in the middle of dense areas.
All the far-away datapoints (not dense, not surrounded by others) are outliers.
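Here is a tiny NumPy sketch of just the part I needed, the noise mask (a ‘core’ point has at least `min_pts` neighbours within `eps`; anything not within `eps` of some core point is an outlier). This is the naive O(n²) version for toy data, nothing like the optimised implementation:

```python
import numpy as np

def dbscan_noise_mask(data, eps, min_pts):
    """Boolean mask: True for outliers ('noise' in DBSCAN terms)."""
    # Pairwise distances (naive O(n^2): fine for toy data only)
    dists = np.linalg.norm(data[:, None, :] - data[None, :, :], axis=2)
    neighbours = dists <= eps
    # Core points: at least min_pts neighbours within eps (self included)
    core = neighbours.sum(axis=1) >= min_pts
    # Noise: not within eps of any core point
    return ~(neighbours & core[None, :]).any(axis=1)

cluster = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1], [0.2, 0.0]])
data = np.vstack([cluster, [[10.0, 10.0]]])  # one far-away outlier
mask = dbscan_noise_mask(data, eps=0.5, min_pts=4)
data_clean = data[~mask]  # drops only [10, 10]
```

This skips the cluster-labelling half of DBSCAN entirely, since for outlier removal only the noise mask matters.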
Though simple in principle, my initial attempt to implement the algorithm in Python failed spectacularly: a machine with 16 GB of RAM and an 8-core CPU ran for half an hour and barely made any progress. Luckily, after some dynamic programming and NumPy's handy set-operation functions, my optimised implementation finished in 7 seconds. 💪
Now, I could finally remove the outliers and get back to good old k-means clustering!
5:K-means (Round 2, with K-means++)
Let me start with the results this time. After cleaning up the data with DBSCAN and normalisation, K-means performed much better than before!
What’s this ‘K-means++’ that I’m referring to? Recall that my first attempt at using K-means clustering failed because of poor initialisation of the first ‘guesses’ (AKA centroid locations). Almost all guesses were far from almost all datapoints. K-means++ is a method to choose initial guesses that ‘cover’ all the datapoints.
TL;DR: the algorithm starts by choosing one random guess. Then it sets the other guesses to be copies of datapoints (remember, the datapoints and guesses are just vectors being initialised). The trick is that datapoints far away from the existing guesses are more likely to be chosen. 🧠
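A minimal sketch of that k-means++ seeding step (my own toy version; libraries like scikit-learn ship this built in):

```python
import numpy as np

def kmeans_pp_init(data, k, rng):
    """Pick the first guess uniformly from the data; pick each later guess
    with probability proportional to its squared distance from the nearest
    guess chosen so far (far-away datapoints are more likely)."""
    centroids = [data[rng.integers(len(data))]]
    for _ in range(k - 1):
        d2 = np.min([np.sum((data - c) ** 2, axis=1) for c in centroids], axis=0)
        centroids.append(data[rng.choice(len(data), p=d2 / d2.sum())])
    return np.array(centroids)

rng = np.random.default_rng(0)
data = np.array([[0.0, 0.0]] * 5 + [[100.0, 100.0]] * 5)
guesses = kmeans_pp_init(data, k=2, rng=rng)
# One guess comes from each clump: duplicates of the first pick have squared
# distance 0, so the second pick must land in the other clump.
```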
The graphs above were just screenshots of me trying to get the algorithms to work. With that done, we're going to find the best hyperparameters to run them with and test them on real IoT devices! Stay posted for cool updates with real-life hardware demos soon :-)