Logistic regression is a classification algorithm traditionally limited to only two-class classification problems.
If you have more than two classes, then Linear Discriminant Analysis is the preferred linear classification technique.
In this post you will discover the Linear Discriminant Analysis (LDA) algorithm for classification predictive modeling problems. After reading this post you will know:
- The limitations of logistic regression and the need for linear discriminant analysis.
- The representation of the model that is learned from data and can be saved to file.
- How the model is estimated from your data.
- How to make predictions from a learned LDA model.
- How to prepare your data to get the most from the LDA model.
This post is intended for developers interested in applied machine learning, how the models work and how to use them well. As such, no background in statistics or linear algebra is required, although it does help if you know about the mean and variance of a distribution.
LDA is a simple model in both preparation and application. There are some interesting statistics behind how the model is set up and how the prediction equation is derived, but they are not covered in this post.
Let’s get started.
Limitations of Logistic Regression
Logistic regression is a simple and powerful linear classification algorithm. It also has limitations that suggest the need for alternative linear classification algorithms.
- Two-Class Problems. Logistic regression is intended for two-class or binary classification problems. It can be extended for multi-class classification, but is rarely used for this purpose.
- Unstable With Well Separated Classes. Logistic regression can become unstable when the classes are well separated.
- Unstable With Few Examples. Logistic regression can become unstable when there are few examples from which to estimate the parameters.
Linear Discriminant Analysis addresses each of these points and is the go-to linear method for multi-class classification problems. Even with binary classification problems, it is a good idea to try both logistic regression and linear discriminant analysis.
Representation of LDA Models
The representation of LDA is straightforward.
It consists of statistical properties of your data, calculated for each class. For a single input variable (x) this is the mean of the variable for each class and its variance, which LDA shares across the classes. For multiple variables, these are the same properties calculated over the multivariate Gaussian, namely the means and the covariance matrix.
These statistical properties are estimated from your data and plugged into the LDA equation to make predictions. These are the model values that you would save to file for your model.
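To make the "model values you would save to file" concrete, here is a minimal sketch for the univariate case. It assumes a made-up container class, `LDAModel`, and JSON as the file format; neither comes from any particular library. The numbers in the example are illustrative only.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class LDAModel:
    """Hypothetical container for a univariate LDA model's learned values."""
    means: dict      # class label -> mean of x for that class
    variance: float  # single variance shared by all classes
    priors: dict     # class label -> prior probability of the class

    def to_json(self) -> str:
        # The whole model is just a handful of numbers, so it serializes easily.
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, text: str) -> "LDAModel":
        return cls(**json.loads(text))


# Illustrative values; class labels are strings because JSON keys must be strings.
model = LDAModel(means={"0": 1.0, "1": 5.0}, variance=0.8, priors={"0": 0.5, "1": 0.5})
saved = model.to_json()
restored = LDAModel.from_json(saved)
print(restored == model)  # True
```

Storing the per-class means, one shared variance and the priors is all a univariate LDA model needs in order to make predictions later.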
Let’s look at how these parameters are estimated.
Learning LDA Models
LDA makes some simplifying assumptions about your data:
- That your data is Gaussian, that is, each variable is shaped like a bell curve when plotted.
- That each class has the same variance, that is, the values of each variable vary around their class mean by the same amount on average in every class.
With these assumptions, the LDA model estimates the mean of each class and a single variance shared by all classes from your data. It is easy to think about this in the univariate (single input variable) case with two classes.
The mean (mu) value of each input (x) for each class (k) can be estimated in the normal way by dividing the sum of values by the total number of values.
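$$\mu_k = \frac{1}{n_k} \sum_{i:\, y_i = k} x_i$$
Where $\mu_k$ is the mean value of x for class k and $n_k$ is the number of instances with class k (the sum runs over those instances). The variance is calculated across all classes as the average squared difference of each value from the mean of its own class:
$$\sigma^2 = \frac{1}{n - K} \sum_{i=1}^{n} \left(x_i - \mu_{y_i}\right)^2$$
Where $\sigma^2$ is the variance shared by all classes, n is the number of instances, K is the number of classes and $\mu_{y_i}$ is the mean of the class that $x_i$ belongs to. Dividing by n - K rather than n accounts for the K class means that were estimated from the same data.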
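The two estimates above can be sketched in a few lines of plain Python. This is a minimal illustration rather than a reference implementation: the function name `estimate_lda_params` and the toy data are made up for the example, and it assumes a single numeric input with at least one instance per class.

```python
def estimate_lda_params(xs, ys):
    """Estimate per-class means and the pooled (shared) variance."""
    classes = sorted(set(ys))
    n, K = len(xs), len(classes)
    means = {}
    for k in classes:
        xk = [x for x, y in zip(xs, ys) if y == k]
        means[k] = sum(xk) / len(xk)  # mu_k = 1/n_k * sum(x)
    # Squared deviation of every value from the mean of its *own* class.
    ss = sum((x - means[y]) ** 2 for x, y in zip(xs, ys))
    variance = ss / (n - K)  # divide by n - K, not n
    return means, variance


xs = [1.0, 2.0, 3.0, 6.0, 7.0, 8.0]
ys = [0, 0, 0, 1, 1, 1]
means, variance = estimate_lda_params(xs, ys)
print(means, variance)  # {0: 2.0, 1: 7.0} 1.0
```

Here each class contributes a squared deviation sum of 2.0, so the pooled variance is (2.0 + 2.0) / (6 - 2) = 1.0.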
Making Predictions with LDA
LDA makes predictions by estimating the probability that a new set of inputs belongs to each class. The class with the highest probability is the predicted output class.
The model uses Bayes’ Theorem to estimate the probabilities. Briefly, Bayes’ Theorem can be used to estimate the probability of the output class (k) given the input (x) using the probability of each class and the probability of the data given each class:
$$P(Y = k \mid X = x) = \frac{\pi_k f_k(x)}{\sum_{l=1}^{K} \pi_l f_l(x)}$$
Where $\pi_k$ refers to the base probability of class k observed in your training data (e.g. 0.5 for a 50-50 split in a two-class problem). In Bayes’ Theorem this is called the prior probability.
$$\pi_k = \frac{n_k}{n}$$
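As a sketch of Bayes’ Theorem at work, the snippet below computes the priors from class counts and then the posterior probability of each class, using a Gaussian density for f_k(x) with a shared variance. The helper names (`estimate_priors`, `gaussian_pdf`, `posteriors`) and the toy parameter values are assumptions made for illustration.

```python
import math
from collections import Counter


def estimate_priors(ys):
    # pi_k = n_k / n
    counts = Counter(ys)
    n = len(ys)
    return {k: c / n for k, c in counts.items()}


def gaussian_pdf(x, mean, variance):
    # Density of a normal distribution with the given mean and variance.
    return math.exp(-((x - mean) ** 2) / (2 * variance)) / math.sqrt(2 * math.pi * variance)


def posteriors(x, means, variance, priors):
    """P(Y=k | X=x) for every class k via Bayes' Theorem."""
    numerators = {k: priors[k] * gaussian_pdf(x, means[k], variance) for k in means}
    total = sum(numerators.values())  # the denominator: sum over all classes l
    return {k: v / total for k, v in numerators.items()}


priors = estimate_priors([0, 0, 0, 1, 1, 1])  # {0: 0.5, 1: 0.5}
means = {0: 2.0, 1: 7.0}
variance = 1.0
probs = posteriors(4.0, means, variance, priors)
prediction = max(probs, key=probs.get)
print(prediction)  # 0
```

The point x = 4.0 is closer to the class 0 mean, and with equal priors that class receives the higher posterior probability (about 0.92).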
The $f_k(x)$ above is the estimated probability density of x for class k, that is, how likely the value x is if it belongs to class k. A Gaussian distribution function is used for $f_k(x)$. Plugging the Gaussian into the above equation, taking the logarithm and dropping the terms that are the same for every class, we end up with the equation below. This is called a discriminant function, and the class with the largest value is the output classification (y):
$$D_k(x) = x \cdot \frac{\mu_k}{\sigma^2} - \frac{\mu_k^2}{2\sigma^2} + \ln(\pi_k)$$
$D_k(x)$ is the discriminant function for class k given input x; $\mu_k$, $\sigma^2$ and $\pi_k$ are all estimated from your data.
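The discriminant function translates directly into code. This minimal sketch assumes univariate input; the function names and the toy parameters (means 2.0 and 7.0, variance 1.0) are made up for illustration and would normally come from your training data.

```python
import math


def discriminant(x, mean_k, variance, prior_k):
    # D_k(x) = x * mu_k / sigma^2 - mu_k^2 / (2 * sigma^2) + ln(pi_k)
    return x * mean_k / variance - mean_k ** 2 / (2 * variance) + math.log(prior_k)


def predict(x, means, variance, priors):
    # The class with the largest discriminant value is the prediction.
    scores = {k: discriminant(x, means[k], variance, priors[k]) for k in means}
    return max(scores, key=scores.get)


means = {0: 2.0, 1: 7.0}
variance = 1.0
equal_priors = {0: 0.5, 1: 0.5}
print([predict(x, means, variance, equal_priors) for x in (1.0, 4.0, 4.6, 9.0)])  # [0, 0, 1, 1]

# A strong prior for class 0 shifts the decision boundary toward class 1.
print(predict(4.6, means, variance, {0: 0.9, 1: 0.1}))  # 0
```

With equal priors the boundary sits halfway between the two means (x = 4.5); the ln(pi_k) term moves it when the classes are imbalanced.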
How to Prepare Data for LDA
This section lists some suggestions you may consider when preparing your data for use with LDA.
- Classification Problems. This might go without saying, but LDA is intended for classification problems where the output variable is categorical. LDA supports both binary and multi-class classification.
- Gaussian Distribution. The standard implementation of the model assumes a Gaussian distribution of the input variables. Consider reviewing the univariate distributions of each attribute and using transforms to make them more Gaussian-looking (e.g. log and root for exponential distributions and Box-Cox for skewed distributions).
- Remove Outliers. Consider removing outliers from your data. These can skew the basic statistics used to separate classes in LDA, such as the mean and the standard deviation.
- Same Variance. LDA assumes that each input variable has the same variance in every class. It is almost always a good idea to standardize your data before using LDA so that it has a mean of 0 and a standard deviation of 1.
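The preparation steps above can be sketched for a single input variable with the standard library alone. The 1.5 × IQR outlier rule, the `log1p` transform and the sample values are assumptions chosen for illustration; other transforms (such as Box-Cox) or outlier rules may suit your data better.

```python
import math
import statistics


def remove_outliers_iqr(values, factor=1.5):
    # Assumed rule: drop values outside [Q1 - factor*IQR, Q3 + factor*IQR].
    q1, _, q3 = statistics.quantiles(values, n=4)
    iqr = q3 - q1
    low, high = q1 - factor * iqr, q3 + factor * iqr
    return [v for v in values if low <= v <= high]


def log_transform(values):
    # log(1 + v): pulls in the long right tail of non-negative skewed data.
    return [math.log1p(v) for v in values]


def standardize(values):
    # Rescale to mean 0 and (population) standard deviation 1.
    mu = statistics.mean(values)
    sd = statistics.pstdev(values)
    return [(v - mu) / sd for v in values]


raw = [1.2, 0.8, 1.5, 2.1, 0.9, 1.1, 40.0]  # illustrative; 40.0 is an outlier
cleaned = remove_outliers_iqr(raw)
prepared = standardize(log_transform(cleaned))
print(cleaned)  # [1.2, 0.8, 1.5, 2.1, 0.9, 1.1]
```

With several input variables you would drop whole rows rather than individual values, so that inputs and class labels stay aligned.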
Extensions to LDA
Linear Discriminant Analysis is a simple and effective method for classification. Because it is simple and so well understood, there are many extensions and variations to the method. Some popular extensions include:
- Quadratic Discriminant Analysis (QDA): Each class uses its own estimate of variance (or covariance when there are multiple input variables).
- Flexible Discriminant Analysis (FDA): Non-linear combinations of inputs, such as splines, are used.
- Regularized Discriminant Analysis (RDA): Introduces regularization into the estimate of the variance (actually covariance), moderating the influence of different variables on LDA.
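The difference between LDA's shared variance and QDA's per-class variances can be shown in a few lines for one input variable. The `regularized_variances` function is a simplified, hypothetical blend controlled by a single `alpha` parameter, meant only to convey the flavour of RDA-style regularization and not its exact published form. The data are toy values.

```python
def class_variances(xs, ys):
    """QDA-style: a separate variance for each class (divide by n_k - 1)."""
    out = {}
    for k in sorted(set(ys)):
        xk = [x for x, y in zip(xs, ys) if y == k]
        mu = sum(xk) / len(xk)
        out[k] = sum((x - mu) ** 2 for x in xk) / (len(xk) - 1)
    return out


def pooled_variance(xs, ys):
    """LDA-style: one variance shared by all classes (divide by n - K)."""
    classes = set(ys)
    means = {k: sum(x for x, y in zip(xs, ys) if y == k) / ys.count(k) for k in classes}
    ss = sum((x - means[y]) ** 2 for x, y in zip(xs, ys))
    return ss / (len(xs) - len(classes))


def regularized_variances(xs, ys, alpha):
    """Hypothetical blend: alpha=1 gives QDA's variances, alpha=0 gives LDA's."""
    pooled = pooled_variance(xs, ys)
    return {k: alpha * v + (1 - alpha) * pooled for k, v in class_variances(xs, ys).items()}


xs = [1.0, 2.0, 3.0, 5.0, 7.0, 9.0]
ys = [0, 0, 0, 1, 1, 1]
print(class_variances(xs, ys))             # {0: 1.0, 1: 4.0}
print(pooled_variance(xs, ys))             # 2.5
print(regularized_variances(xs, ys, 0.5))  # {0: 1.75, 1: 3.25}
```

Intermediate values of `alpha` let each class keep some of its own spread while borrowing stability from the pooled estimate, which helps when some classes have few examples.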
The original development was called the Linear Discriminant or Fisher’s Discriminant Analysis. The multi-class version was referred to as Multiple Discriminant Analysis. These are all now simply referred to as Linear Discriminant Analysis.
I have to credit the book An Introduction to Statistical Learning: with Applications in R; some of the description and the notation in this post were taken from this text. It’s excellent.
In this post you discovered Linear Discriminant Analysis for classification predictive modeling problems. You learned:
- The model representation for LDA and what is actually distinct about a learned model.
- How the parameters of the LDA model can be estimated from training data.
- How the model can be used to make predictions on new data.
- How to prepare your data to get the most from the method.
Do you have any questions about this post?
Leave a comment and ask, and I will do my best to answer.