ML.NET with K-means Clustering

This is part two of my deep dive into ML.NET. My last post, on using SDCA to predict turnover, turned out to be a pretty good gateway into ML.NET, with a relatively easy problem to “solve” and evaluate. Wanting to dive into a different algorithm, I chose k-means for clustering. In the meantime, Microsoft released version 0.9 of ML.NET. The API has changed slightly, so I have updated the source code for the SDCA regression and commonized the shared library code to take advantage of the new API and syntax.

Having been interested in security for years, and working in it daily, I wanted to switch gears to a security focus. A common problem in the ML world is how to classify a threat once it has been found. Calling it “Abnormal” or “Unsafe”, as some of the industry does, is pretty uninformative in my opinion. This is where clustering comes into play.

With k-means clustering, the idea is to treat each sample as a point in feature space (effectively a scatter plot) and partition those points into k clusters, with each point belonging to the cluster whose centroid (the mean of its members) is nearest. In my case, each cluster would be a threat category such as Trojan, PUA or, generically, Virus. In a production environment you would probably want to break it out further into Worms, Rootkits, Backdoors, etc., but to keep it easy, I stuck with just those three.
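
To make the idea concrete, here's a minimal sketch of the classic k-means loop (Lloyd's algorithm) in plain C#, with no ML.NET dependency. The class and method names are hypothetical, and the starting centroids are passed in explicitly so the example is deterministic (real implementations usually pick them randomly or with k-means++). This is an illustration of the technique, not how ML.NET's KMeans trainer is implemented internally.

```csharp
using System;
using System.Linq;

// Hypothetical illustration of Lloyd's algorithm; not ML.NET's implementation.
public static class KMeansSketch
{
    // Returns, for each point, the index of the centroid it ends up assigned to.
    public static int[] Cluster(double[][] points, double[][] initialCentroids, int maxIterations = 100)
    {
        // Copy the centroids so the caller's arrays are not mutated.
        var centroids = initialCentroids.Select(c => (double[])c.Clone()).ToArray();
        var assignment = new int[points.Length];

        for (int iter = 0; iter < maxIterations; iter++)
        {
            bool changed = false;

            // Assignment step: each point joins the cluster with the nearest centroid.
            for (int i = 0; i < points.Length; i++)
            {
                int nearest = 0;
                for (int c = 1; c < centroids.Length; c++)
                    if (SquaredDistance(points[i], centroids[c]) < SquaredDistance(points[i], centroids[nearest]))
                        nearest = c;
                if (iter == 0 || assignment[i] != nearest) changed = true;
                assignment[i] = nearest;
            }

            // Converged: no point switched clusters.
            if (!changed) break;

            // Update step: move each centroid to the mean of its assigned points.
            for (int c = 0; c < centroids.Length; c++)
            {
                var members = points.Where((p, i) => assignment[i] == c).ToArray();
                if (members.Length == 0) continue; // leave an empty cluster's centroid in place
                for (int d = 0; d < centroids[c].Length; d++)
                    centroids[c][d] = members.Average(p => p[d]);
            }
        }

        return assignment;
    }

    // Squared Euclidean distance; the nearest centroid is the same with or without the square root.
    static double SquaredDistance(double[] a, double[] b) =>
        a.Zip(b, (x, y) => (x - y) * (x - y)).Sum();
}
```

With two well-separated groups of points and starting centroids at (0,0) and (10,10), the first three points land in cluster 0 and the last three in cluster 1, and the loop stops after the second pass because nothing moves.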

Another piece I didn’t need for my last deep dive is feature extraction. Again, to keep things easy, I limited it to PE32/PE32+ files and used the PeNet NuGet package to extract two features:

  1. Size of Raw Data (from the Image Section Header)
  2. Number of Imports (from the Image Import Directory)
In a production model you would need considerably more features, especially for more granular classification.
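
PeNet does the header parsing for you, but for the curious, here's a rough sketch of what the first feature actually reads, walking the raw PE headers with nothing but the BCL. The helper name is hypothetical and it is not PeNet's API. Summing SizeOfRawData across all sections is also an assumption, since which section(s) the feature covers isn't spelled out. It skips the bounds checking that real code running on untrusted binaries would need.

```csharp
using System;
using System.Buffers.Binary;
using System.IO;

// Hypothetical helper (not the PeNet API): reads raw PE headers using only the BCL.
public static class PeHeaderSketch
{
    // Sums IMAGE_SECTION_HEADER.SizeOfRawData over every section of a PE32/PE32+ image.
    // Assumption: the feature is the total across sections. No bounds checks, for brevity.
    public static long TotalSizeOfRawData(byte[] image)
    {
        // DOS header: e_lfanew (file offset of the NT headers) is stored at 0x3C.
        int ntHeaders = BinaryPrimitives.ReadInt32LittleEndian(image.AsSpan(0x3C));

        // NT headers begin with the 4-byte signature "PE\0\0".
        if (image[ntHeaders] != (byte)'P' || image[ntHeaders + 1] != (byte)'E' ||
            image[ntHeaders + 2] != 0 || image[ntHeaders + 3] != 0)
            throw new InvalidDataException("Missing PE signature");

        // The 20-byte COFF file header follows the signature.
        int fileHeader = ntHeaders + 4;
        ushort numberOfSections = BinaryPrimitives.ReadUInt16LittleEndian(image.AsSpan(fileHeader + 2));
        ushort sizeOfOptionalHeader = BinaryPrimitives.ReadUInt16LittleEndian(image.AsSpan(fileHeader + 16));

        // The section table starts right after the optional header (same rule for PE32 and PE32+).
        int sectionTable = fileHeader + 20 + sizeOfOptionalHeader;

        long total = 0;
        for (int i = 0; i < numberOfSections; i++)
        {
            // Each section header is 40 bytes; SizeOfRawData sits at offset 16 within it.
            int header = sectionTable + i * 40;
            total += BinaryPrimitives.ReadUInt32LittleEndian(image.AsSpan(header + 16));
        }
        return total;
    }
}
```

Reading the optional header size from the file header, rather than hard-coding 224 or 240 bytes, is what lets the same walk work for both PE32 and PE32+.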

Some Code Cleanup for 0.9

One of the big things I did the other night, after updating to 0.9, was commonizing more of the code, and luckily the new APIs allow that. One of the biggest achievements was getting the Predict function 100% generic:

public static TK Predict<T, TK>(MLContext mlContext, string modelPath, T predictionData) where T : class where TK : class, new()
{
    ITransformer trainedModel;

    // Load the serialized model from disk
    using (var stream = new FileStream(modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        trainedModel = mlContext.Model.Load(stream);
    }

    // Build a prediction engine mapping an input of type T to an output of type TK
    var predFunction = trainedModel.CreatePredictionEngine<T, TK>(mlContext);

    return predFunction.Predict(predictionData);
}

public static TK Predict<T, TK>(MLContext mlContext, string modelPath, string predictionFilePath) where T : class where TK : class, new()
{
    // Deserialize the JSON representation of T, then reuse the overload above
    var data = File.ReadAllText(predictionFilePath);

    var predictionData = Newtonsoft.Json.JsonConvert.DeserializeObject<T>(data);

    return Predict<T, TK>(mlContext, modelPath, predictionData);
}
For the clustering I took it a step further, allowing you to pass in either an instance of T or the file path of a JSON representation of T. The reason is that typical threat classification tools like ClamAV or VirusTotal let you just upload a file or scan it from the command line.
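
The JSON file for the second overload is just a serialized instance of the input class. As a rough sketch of what such a file might contain and how it maps back to an object, here it is with System.Text.Json, swapped in for the Newtonsoft.Json call above so the snippet runs without any NuGet packages. ThreatInput is a hypothetical copy of ThreatInformation's properties, and the sample values are made up.

```csharp
using System.Text.Json;

// Hypothetical copy of ThreatInformation's properties so the sketch is self-contained.
public class ThreatInput
{
    public float NumberImports { get; set; }
    public float DataSizeInBytes { get; set; }
    public string Classification { get; set; }
}

public static class JsonInputSketch
{
    // Stand-in for File.ReadAllText(path) + DeserializeObject<T>(data):
    // turns the JSON text of a prediction input back into a typed object.
    public static T FromJson<T>(string json) => JsonSerializer.Deserialize<T>(json);
}
```

A file containing `{"NumberImports":42,"DataSizeInBytes":8192,"Classification":"Trojan"}` deserializes into a ThreatInput with those three values. One difference to keep in mind if you switch libraries: Newtonsoft.Json matches property names case-insensitively by default, while System.Text.Json is case-sensitive unless you set PropertyNameCaseInsensitive on JsonSerializerOptions.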

Another area of improvement was standardizing the command line arguments, especially with future experiments on the horizon. An improved, but not perfect, change was to use an enum:

public enum MLOperations { train, predict, featureextraction }
And then in Program.cs:

if (!Enum.TryParse<MLOperations>(args[0], out var mlOperation))
{
    Console.WriteLine($"{args[0]} is an invalid argument");

    Console.WriteLine("Available Options:");

    Console.WriteLine(string.Join(", ", Enum.GetNames(typeof(MLOperations))));

    return;
}

switch (mlOperation)
{
    case MLOperations.train:
        TrainModel(mlContext, args[1], args[2]);
        break;
    case MLOperations.predict:
        var extraction = FeatureExtractFile(args[2], true);

        // Stop if feature extraction returned nothing
        if (extraction == null)
        {
            return;
        }

        Console.WriteLine($"Predicting on {args[2]}:");

        var prediction = Predictor.Predict<ThreatInformation, ThreatPredictor>(mlContext, args[1], extraction);

        Console.WriteLine($"Cluster ID: {prediction.ThreatClusterId}");
        break;
    case MLOperations.featureextraction:
        FeatureExtraction(args[1], args[2]);
        break;
}
Using the enum allowed a quick sanity check of the first argument, and then a switch/case for each of the operations. For the next deep dive I will clean this up a bit more, probably with an interface or abstract class that each experiment implements.
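
One gotcha with the enum approach: Enum.TryParse also accepts any numeric string, so an argument like "7" parses successfully into an undefined MLOperations value and silently matches no case in the switch. Here's a small sketch of a stricter parse. The ArgParsing class and TryParseOperation method are hypothetical; the enum members are the ones the switch uses.

```csharp
using System;

public enum MLOperations { train, predict, featureextraction }

// Hypothetical helper for validating the first CLI argument.
public static class ArgParsing
{
    // ignoreCase lets "Train" or "TRAIN" work; the IsDefined check rejects numeric
    // strings such as "7", which Enum.TryParse would otherwise accept as (MLOperations)7.
    public static bool TryParseOperation(string arg, out MLOperations operation) =>
        Enum.TryParse<MLOperations>(arg, true, out operation) &&
        Enum.IsDefined(typeof(MLOperations), operation);
}
```

With this, "Train" maps to MLOperations.train, while "7" and "evaluate" are both rejected, so the "invalid argument" branch can print the available options instead of falling through.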

K-means Clustering

Much like with SDCA, the code to train a model was very easy:

private static void TrainModel(MLContext mlContext, string trainDataPath, string modelPath)
{
    var modelObject = Activator.CreateInstance<ThreatInformation>();

    // Column definitions come from the extension methods on the model class
    var textReader = mlContext.Data.CreateTextReader(columns: modelObject.ToColumns(), hasHeader: false, separatorChar: ',');

    var dataView = textReader.Read(trainDataPath);

    // Concatenate the listed columns into a single feature vector, then cluster it
    var pipeline = mlContext.Transforms
        .Concatenate(Constants.FEATURE_COLUMN_NAME, modelObject.ToColumnNames())
        .Append(mlContext.Clustering.Trainers.KMeans(Constants.FEATURE_COLUMN_NAME,
                clustersCount: Enum.GetNames(typeof(ThreatTypes)).Length));

    var trainedModel = pipeline.Fit(dataView);

    using (var fs = File.Create(modelPath))
    {
        trainedModel.SaveTo(mlContext, fs);
    }

    Console.WriteLine($"Saved model to {modelPath}");
}
With the new 0.9 API, the text reader has been cleaned up (in conjunction with the extension methods I created earlier). The critical piece to keep in mind is the clustersCount argument to the KMeans trainer: you want this number to equal the number of categories you have. Since my categories are an enum, I keep the code flexible by simply counting its names. I strongly suggest following that path to avoid errors down the road. The rest of the code is generic (with room for some refactoring in the next deep dive).
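
Here's a small sketch of the two pieces of schema plumbing this relies on, in plain C#. The ThreatTypes members are assumed from the three categories named earlier. FeatureColumnNames is a hypothetical reflection-based stand-in for the ToColumnNames() extension method, whose real implementation isn't shown, and ThreatFeatures is a hypothetical copy of ThreatInformation.

```csharp
using System;
using System.Linq;
using System.Reflection;

// Assumed members, based on the categories named in the post.
public enum ThreatTypes { Trojan, PUA, Virus }

// Hypothetical copy of ThreatInformation's properties.
public class ThreatFeatures
{
    public float NumberImports { get; set; }
    public float DataSizeInBytes { get; set; }
    public string Classification { get; set; }
}

public static class ModelSchemaSketch
{
    // k follows the enum: adding a category automatically bumps the cluster count.
    public static int ClusterCount<TEnum>() where TEnum : struct, Enum =>
        Enum.GetNames(typeof(TEnum)).Length;

    // Hypothetical stand-in for ToColumnNames(): every public float property
    // becomes an input column for the Concatenate transform.
    public static string[] FeatureColumnNames<T>() =>
        typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance)
                 .Where(p => p.PropertyType == typeof(float))
                 .Select(p => p.Name)
                 .ToArray();
}
```

For these types, ClusterCount gives 3, and FeatureColumnNames returns NumberImports and DataSizeInBytes while leaving the string Classification column out of the feature vector. Reflection doesn't formally guarantee property order, so code that depends on column order should sort or list names explicitly.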

My threat classification class should look like a pretty normal class:

public class ThreatInformation
{
    public float NumberImports { get; set; }

    public float DataSizeInBytes { get; set; }

    public string Classification { get; set; }

    // Emits one CSV row (label first) for the feature extraction output
    public override string ToString() => $"{Classification},{NumberImports},{DataSizeInBytes}";
}
I overrode ToString() for the feature extraction output, but otherwise it’s a pretty normal class.
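
One caveat with the ToString() approach: string interpolation formats floats using the current culture, so on a machine set to a culture like German, a value of 1.5 is written as "1,5" and silently adds a column to the comma-separated row. Whole-number values are unaffected, but if a feature ever carries a fraction this matters. Here's a minimal sketch of an invariant-culture variant. ThreatRow and ToCsvLine are hypothetical stand-ins for ThreatInformation and its ToString().

```csharp
using System;

// Hypothetical stand-in for ThreatInformation.
public class ThreatRow
{
    public float NumberImports { get; set; }
    public float DataSizeInBytes { get; set; }
    public string Classification { get; set; }

    // Same column order as the post's ToString(), but formatted with the invariant
    // culture so decimals always use '.' regardless of the machine's locale.
    public string ToCsvLine() =>
        FormattableString.Invariant($"{Classification},{NumberImports},{DataSizeInBytes}");
}
```

For a Trojan with 12 imports and a DataSizeInBytes of 1.5, ToCsvLine produces "Trojan,12,1.5" on any machine.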

My prediction class is a little different than with SDCA:

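public class ThreatPredictor
{
    [ColumnName("PredictedLabel")] // cluster chosen by k-means
    public uint ThreatClusterId { get; set; }

    [ColumnName("Score")] // distances to each cluster centroid
    public float[] Distances { get; set; }
}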
Whereas SDCA and other regression models return a value, the k-means trainer returns the cluster it found to be the best fit. The Distances array contains the Euclidean distances from the data passed in for prediction to each of the cluster centroids. For my case, I added a translation from the cluster ID to a human-readable string value (e.g. Trojan, Virus, etc.).
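
Because k-means is unsupervised, the cluster IDs it hands back are arbitrary: cluster 1 isn't guaranteed to correspond to the first member of the threat enum, so the ID-to-name translation has to come from labeled data. As one possible approach (not necessarily the one used here), this sketch picks the nearest centroid from a Distances array and names each cluster by majority vote over labeled training rows. The method names are hypothetical, and the 1-based ID is an assumption based on ML.NET's convention for key-typed columns such as PredictedLabel.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helpers for turning cluster output into threat names.
public static class ClusterLabelSketch
{
    // Index of the smallest distance, returned as a 1-based cluster ID (assumed
    // to match ML.NET's key convention). The argmin is the same whether the
    // distances are plain or squared.
    public static uint NearestClusterId(float[] distances)
    {
        int best = 0;
        for (int i = 1; i < distances.Length; i++)
            if (distances[i] < distances[best]) best = i;
        return (uint)(best + 1);
    }

    // Run the labeled training rows through the model, then give each cluster the
    // label that occurs most often among its members (majority vote).
    public static Dictionary<uint, string> NameClusters(IEnumerable<(uint ClusterId, string Label)> labeledPredictions) =>
        labeledPredictions
            .GroupBy(p => p.ClusterId)
            .ToDictionary(
                g => g.Key,
                g => g.GroupBy(p => p.Label)
                      .OrderByDescending(l => l.Count())
                      .ThenBy(l => l.Key, StringComparer.Ordinal) // deterministic tie-break
                      .First().Key);
}
```

For example, distances of {5.0, 0.5, 3.0} give cluster ID 2, and if cluster 1 held two Trojan rows and one Virus row, it would be named Trojan. A cluster with no clear majority is a hint that the two features aren't separating those categories well.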

Closing Thoughts

I was surprised at how quick both training the model and running predictions were. From digging into the code on GitHub, everything looks to be parallelized as much as possible. After using other technologies that aren’t multi-threaded, I found this a refreshing sight. As for working with clustering further, I think the big things I will work on are scalable feature extraction and efficient training (right now it’s single-threaded and everything is loaded into memory at once).
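
On the single-threaded extraction, here's a rough sketch of one way to parallelize it without building the whole result set in memory first. PLINQ fans the per-file work out across cores, and File.WriteAllLines consumes the results as they arrive. The method names and the extract delegate are hypothetical; extract would presumably wrap the PeNet-based per-file extraction and return one CSV line, or null for files that aren't PEs.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

// Hypothetical sketch of parallel, streaming feature extraction.
public static class ParallelExtractionSketch
{
    // Runs the per-file extractor in parallel and yields non-null lines lazily.
    // Output order is not preserved (no AsOrdered), which is fine for a training CSV.
    public static IEnumerable<string> ExtractLines(IEnumerable<string> files, Func<string, string> extract) =>
        files.AsParallel()
             .Select(extract)
             .Where(line => line != null);

    // File.WriteAllLines enumerates the sequence on the calling thread, so the
    // writer stays single-threaded while the extraction work runs in parallel.
    public static void ExtractToCsv(string inputDirectory, string outputPath, Func<string, string> extract) =>
        File.WriteAllLines(outputPath,
            ExtractLines(Directory.EnumerateFiles(inputDirectory, "*", SearchOption.AllDirectories), extract));
}
```

Directory.EnumerateFiles is lazy too, so neither the file list nor the extracted rows need to be materialized up front.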