3 ways to design effective classes in ML Classification Algorithms

Nitin Pasumarthy
Jul 15, 2019

In this post we will look at three ways to effectively design the target classes in classification problems, using the properties of the training data alone.

Background

Classification is an area in Machine Learning (ML) where the machine is tasked with learning to categorize a given input. For example, given an image, the machine should return the category the image belongs to.

Figure 1: A machine learned function, f, which takes an image as input and returns whether it is a “cat” or a “dog”

As shown in Figure 1, f is trained with labelled training images of cats and dogs. Here, “dog” is one category and “cat” is another. These categories are also referred to as “classes” in classification problems. The goodness of f largely depends on the quality of the data it is trained on — if we feed it images of monkeys labelled as “cat”, that is what it will learn!

Figure 2: Feeding garbage data while training leads to garbage predictions, making the entire process futile

The quality of training data is determined by the quality of the inputs, the classes, and the mappings between them. In this post we will focus on some ideas for designing “classes” in some non-trivial situations. In an earlier post, I shared some ideas for handling high-cardinality input data with hashing, character one-hot encoding and embeddings.

Class imbalance is one of the first things to look out for when generating training data for classification problems. If our training data has 98 examples of dogs and only 2 examples of cats, f will have a tough time learning what a cat actually looks like, as cats are underrepresented. Can you guess another thing to be watchful of here? Here’s a hint: what’s the accuracy of the f below on the above dataset,

def f(image):
    return "dog"

It is 98%, as f correctly predicts 98 of the 100 training images as “dog”, yet it is absolutely useless. It just got lucky because of the bias in the training data. We should always remember that,

ML is a daemon which tries to solve the problem in the worst possible way that you didn’t prohibit

This statement by John Mount really stuck with me. So it is our responsibility, while designing a classification algorithm, to either balance the classes or use a different accuracy metric. All the techniques here ensure this problem is addressed.
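
For instance, one simple option is scikit-learn’s balanced_accuracy_score, which averages per-class recall and immediately exposes the lazy classifier above. A minimal sketch of the two metrics side by side:

from sklearn.metrics import accuracy_score, balanced_accuracy_score

# 98 dogs and 2 cats, mirroring the imbalanced dataset above
y_true = ["dog"] * 98 + ["cat"] * 2
y_pred = ["dog"] * 100  # the lazy classifier that always answers "dog"

print(accuracy_score(y_true, y_pred))           # 0.98, looks great
print(balanced_accuracy_score(y_true, y_pred))  # 0.5, no better than chance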

Scenario 1: Percentile Buckets

This idea is useful when we have a continuous, long-tailed distribution of numbers, say item prices on an e-commerce website, and we wish to classify them as low, medium and high using the attributes of the items.

f( item details ) = price class

In order to prevent class imbalance, we can divide the prices into equal-sized buckets, such that each bucket has approximately the same number of items. However, we have two problems here,

  1. The definitions of low, medium & high could be such that most of the items fall into the medium class, again leading to class imbalance!
  2. These definitions might also change frequently, forcing us to retrain f with new class boundaries each time.

So we need a bucketing scheme which is flexible enough to address both problems.

Solution: what if we split the price column into fine-grained percentile buckets? Then we can merge and report those buckets based on the current definitions of the low, medium & high price classes.

Figure 3: Code showing how a long-tailed item price distribution is bucketed to solve the class imbalance problem
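
Figure 3 is a notebook screenshot, so here is a minimal pandas sketch of the same bucketing idea, plus the merge described next. The price data below is synthetic, not the author’s dataset:

import numpy as np
import pandas as pd

# Synthetic long-tailed price column; real data would come from the catalog
prices = pd.Series(np.random.lognormal(mean=6, sigma=1, size=10_000), name="price")

# qcut makes 10 percentile buckets with roughly the same number of items in
# each, so the fine-grained classes are balanced by construction
bucket = pd.qcut(prices, q=10, labels=False)
print(bucket.value_counts().sort_index())  # ~1,000 items per bucket

# Merging is just a lookup on top of the fine buckets, so redefining
# low / medium / high later does not require retraining the model
bucket_to_class = {**dict.fromkeys(range(0, 4), "low"),
                   **dict.fromkeys(range(4, 7), "medium"),
                   **dict.fromkeys(range(7, 10), "high")}
price_class = bucket.map(bucket_to_class)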

As shown in Figure 3, the price column is bucketed into 10 equal-sized buckets, though we only need 3 — low, medium & high. Now if we are given,

  • < ₹233 as low,
  • ₹234 to ₹685 as medium, and
  • > ₹685 as high

then we can merge them into the [0–3], [4–6] and [7–9] buckets respectively. At model training and monitoring time, we can report the accuracies of these combined buckets, as that is what we care about,

Figure 4: A custom TensorFlow Keras accuracy metric generated by merging percentile buckets

The code in Figure 4 is an optional, advanced setup which helps with monitoring. It also gives an idea of how custom model performance metrics can be defined in TensorFlow Keras, and how they can be used to measure the accuracy of merged percentile buckets. The three_class_acc function grades the predictions made by the model as if there were only 3 classes. This custom metric can be visualized on TensorBoard alongside the default 10-class accuracy metric. This way, we train and maintain a single model and use it for multiple use cases.
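
Since Figure 4 is also a screenshot, here is a minimal sketch of what such a metric could look like, assuming sparse integer labels over the 10 fine buckets. The mapping and function body below are reconstructions for illustration, not the author’s exact code:

import tensorflow as tf

# Each of the 10 fine buckets maps to a merged class:
# buckets 0-3 -> low (0), 4-6 -> medium (1), 7-9 -> high (2)
BUCKET_TO_CLASS = tf.constant([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])

def three_class_acc(y_true, y_pred):
    # Grade the 10-class predictions as if there were only 3 classes
    fine_pred = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
    fine_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
    merged_pred = tf.gather(BUCKET_TO_CLASS, fine_pred)
    merged_true = tf.gather(BUCKET_TO_CLASS, fine_true)
    return tf.reduce_mean(tf.cast(tf.equal(merged_pred, merged_true), tf.float32))

# model.compile(optimizer="adam",
#               loss="sparse_categorical_crossentropy",
#               metrics=["accuracy", three_class_acc])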

Figure 5: Overview of the Percentile Buckets idea, applicable in general to any continuously distributed target class

Scenario 2: Hierarchical Classes

This idea comes in handy when you have a hierarchy of classes at your disposal, and the classification function is free to predict any class in the lineage.

Say we would like to build a classification model which takes a customer bug report as input and predicts which employee of the organization would be the best fit to solve it. Historical, manually triaged bugs could make a good training dataset. Something like,

f( bug report ) = employee id

Such a formulation could have two problems,

  1. If every employee is their own class, there will be too many classes and the model would not see enough training samples for each class.
  2. The number of training samples per class will be skewed, as new and rapidly evolving code bases are prone to more bugs than stable legacy ones.

One way to solve the problem without obtaining additional training data is to use the organization chart. Instead of predicting individual employees, we combine them under a single director or VP to arrive at a reasonable number of classes.

Figure 6: Organization chart of a hypothetical company with target classes highlighted in red

As shown in Figure 6, the highlighted Senior Directors and VPs could be a possible set of target classes, based on the number of bugs (training samples) under each. A predicted bug can then be narrowed down to an individual engineer either manually or using some heuristics.

Figure 7: Code showing how multiple splits and merges to a hierarchy of classes lead to a more balanced set of target classes for a classification problem

The rendered notebook in Figure 7 shows the class explosion problem when the best-fit employees who can solve a bug are used as target classes. On the flip side, if VPs are used as target classes, it could lead to high imbalance. It also shows how the standard deviation (std) of class sizes is a suitable metric for quantifying class imbalance. The split function splits the biggest VP classes into their direct reports, leading to a more balanced set. Similarly, the merge function combines the smallest Senior Director classes, whose counts fall below a given threshold, into the VPs they report to. Repeating these steps multiple times arrives at a std value 5.82 times smaller than the original.
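
Figure 7 is a screenshot as well, so below is a stripped-down sketch of the split and merge idea on a toy hierarchy. All names, counts and the roll-up logic are illustrative assumptions:

import statistics

# Assumed reporting hierarchy: manager class -> direct-report classes
children = {
    "VP_A": ["SD_A1", "SD_A2"],
    "VP_B": ["SD_B1", "SD_B2"],
}
parent = {c: p for p, kids in children.items() for c in kids}

# A class's count is the total number of bugs under that person, so a
# manager's count equals the sum of their reports' counts
counts = {"VP_A": 800, "SD_A1": 450, "SD_A2": 350,
          "VP_B": 90,  "SD_B1": 60,  "SD_B2": 30}

def imbalance(classes):
    # Std of class sizes: smaller means more balanced
    return statistics.pstdev(counts[c] for c in classes)

def split(classes):
    # Replace the biggest class with its direct reports, if it has any
    biggest = max(classes, key=lambda c: counts[c])
    if biggest not in children:
        return classes
    return [c for c in classes if c != biggest] + children[biggest]

def merge(classes, threshold):
    # Fold classes smaller than the threshold back into their parents
    merged = {parent[c] if counts[c] < threshold and c in parent else c
              for c in classes}
    return sorted(merged)

classes = ["VP_A", "VP_B"]
print(imbalance(classes))                   # 355.0, highly imbalanced
classes = split(classes)                    # VP_A is replaced by SD_A1 & SD_A2
print(sorted(classes), imbalance(classes))  # std drops to ~152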

Figure 8: What we achieved by applying the split & merge functions to balance the initial set of 4 highly imbalanced classes. The x-axis is the employee ID (prospective class) and the y-axis is the number of training samples in each class. Note that the y-axis ranges differ between “Before” & “After”

Scenario 3: Balanced Clustering

This idea is useful when the required number of target classes is known but the actual classes are not clearly defined.

Clustering is another sub-field in ML, besides classification, which tries to identify structure & patterns in an unlabeled dataset. Therefore, unlike in classification, there is no right answer for clustering algorithms. Based on the problem, we choose the appropriate clustering algorithm and use the clusters it identifies. For example, in this paper, the authors cluster taxi drop-off points, combining close-by points (latitude & longitude pairs) into virtual regions that are used as target classes. A classification algorithm then uses the 🚖 pick-up points as input and learns to predict the drop-off virtual region. This way they reduce the number of target classes for the classification algorithm.

Here we will see how clustering can help generate balanced classes in graph / network datasets. This idea is applicable in general to any unlabelled dataset where the problem is loosely defined. Say we have a social network dataset (like LinkedIn or Facebook) where nodes are users and edges are connections (e.g. “friends” or “share a common interest”) among them. Now when a new user, Alice, signs up, it is hard to generate any relevant recommendations for her, as we don’t have any historical information about her. This is referred to as the cold-start problem. One way to solve it is to find a group of users similar to Alice using her sign-up information and use their recommendations for her until we get to know her well.

Figure 9: How clustering and classification can work together to address cold-start new user problem

As shown in Figure 9, a way to solve the cold-start user problem is twofold,

  1. Group similar users together — Clustering
  2. Using the new user’s attributes (from the sign up process), find their closest group — Classification

As clustering identifies the target classes for the classification algorithm, we focus here on how to generate balanced clusters to prevent the class imbalance problem.

Figure 10: Using the Spectral Clustering algorithm, we can identify the right number of balanced clusters

As shown in Figure 10, the spectral clustering algorithm was used to identify balanced clusters in a toy social network. Such clustering algorithms usually expect us to provide the number of clusters; if the graph is small, we can simply experiment with multiple values. Otherwise, choosing it may need some domain intuition, or experiments run on a good, representative sample of the graph.
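
Figure 10 is a screenshot too; here is a minimal sketch of the same idea with networkx and scikit-learn, run on an assumed toy graph rather than the author’s dataset:

import networkx as nx
import numpy as np
from sklearn.cluster import SpectralClustering

# Assumed toy social network: two communities of 10 users each
G = nx.connected_caveman_graph(2, 10)
adjacency = nx.to_numpy_array(G)

# affinity="precomputed" says we are passing a graph adjacency matrix
# instead of feature vectors
sc = SpectralClustering(n_clusters=2, affinity="precomputed", random_state=42)
labels = sc.fit_predict(adjacency)
print(np.bincount(labels))  # roughly equal-sized clusters, e.g. [10 10]

# These cluster labels then become the target classes for a standard
# classifier trained on the users' sign-up attributes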

Summary

Classification is an ML technique which works well in many scenarios where labelled data is either available or can be generated from the problem statement itself, as seen in Scenario 3. We saw three techniques for designing the target classes of classification algorithms, which take care of common problems like class imbalance on limited data and the unavailability of labelled data. Do you have any other interesting ways to design the y in classification algorithms? Do share them in the comments.

Special thanks to Atasi Panda for reviewing my content and making it more enjoyable.
