Out-of-distribution (OOD) for ML practitioners
In an ideal world, all data comes from nice normal distributions, and I can sample uniformly at random from the population for my training data. Unfortunately, that is almost never the case.
Out-of-distribution (OOD) in machine learning happens when you have to make a prediction on an input drawn from a different distribution than your training dataset. Technically, any sample at test time can be considered OOD, since it wasn’t used in training. However, it becomes more problematic when there is a systematic difference between your training dataset and your test dataset. A common reason is a bias in training data collection, e.g., a flower picture in your training dataset may feature the flower prominently and in focus, while a picture provided by some user on the internet may be cropped and low-resolution. In many cases, it is too costly or simply impossible to collect labels on your test dataset, no matter how much you intend to correct for these kinds of biases. For example, I won’t be able to collect data points from tomorrow, which is when my model should be deployed. After all, if you could easily collect the true labels of your test target, why would you have wanted to use a model in the first place?
OOD is inevitable, so what should we do about it? Understanding it at a slightly deeper level helps with addressing it more productively.
How is this manifested?
Let’s define your model as a function of some input \(X\),
\[y = f(X),\] where \(f\) aims to approximate \(P(Y \vert X)\). The joint distribution of your data decomposes as
\[P(X, Y) = P(Y \vert X)\, P(X).\] In other words, how well your \(f\) does in practice depends not only on the conditional probability of \(Y\) given \(X\), but also on the distribution of \(X\) it is asked to handle. When a data shift occurs between training and test time, it can happen in two different ways:
- change in \(P(X)\): This is called covariate shift, just a fancy name for a distribution change in \(X\). For example, your training data mostly has sepia-tinted images, but the test data has images with more diverse colors. The color distributions of the two datasets are different.
- change in \(P(Y \vert X)\): When I was young and living in Korea, tall 1-liter cartons were almost always packaging for milk. These days, I see water packaged in boxes of the same shape. The probability that the classic tall carton is milk has changed. This is a concept shift.
(Image: this carton could have water now)
Understanding which of these contributes more to your model’s errors helps, because you would address them differently. Let’s look at each of these in a bit more depth and see what we can do about them.
1. Covariate Shift
Imagine that you had only a single feature in the model. It’d be easy to discover covariate shift: if your distribution is normal-esque, do some t-testing. In the real world, though, you most likely have multiple features with non-zero covariance between them, and their marginal distributions are not normal-esque. Now t-testing becomes much less feasible.
In practice, I have found checking the marginal distribution of each feature to be quite helpful. Simple KL divergence on categorical features and sensibly bucketized continuous features has given an informative first-pass measure of this type of shift.
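To make that concrete, here is a minimal sketch of such a first-pass check, assuming NumPy arrays of a single continuous feature from the train and test sets; the function names, bucket count, and usage are my own illustration, not any particular library’s API:

```python
import numpy as np

def kl_divergence(p_counts, q_counts, eps=1e-9):
    """KL(P || Q) between two histograms given as raw counts."""
    p = p_counts / p_counts.sum()
    q = q_counts / q_counts.sum()
    return float(np.sum(p * np.log((p + eps) / (q + eps))))

def marginal_shift(train_values, test_values, n_buckets=20):
    """Bucketize one continuous feature on the pooled range, then compare histograms."""
    edges = np.histogram_bin_edges(
        np.concatenate([train_values, test_values]), bins=n_buckets
    )
    p, _ = np.histogram(train_values, bins=edges)
    q, _ = np.histogram(test_values, bins=edges)
    return kl_divergence(p, q)

# Hypothetical usage: rank features by how much their marginals moved.
# shifts = {
#     col: marginal_shift(train_df[col].to_numpy(), test_df[col].to_numpy())
#     for col in continuous_columns
# }
# sorted(shifts.items(), key=lambda kv: -kv[1])
```

For categorical features, the same `kl_divergence` applies directly to per-category counts, as long as both sides are aligned on the union of observed categories.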
Once you know for sure that covariate shift is happening, you might be tempted to jump straight into model techniques like self-supervised pre-training, because that sounds fun and smart. But the most important and likely most impactful treatment is to understand why it is happening and to fix the shift in \(X_{\text{train}}\) for the supervised learning setting.
First, look into the data. Could it be that there is a bias in collecting the training data? Is there a way to get labels on a more diverse set of \(X\)? If data-level interventions were not enough, would it be feasible to build a reward function that provides a \(Y\) for every \(X\) during training? Or maybe we should actually solve a related problem where labels are more readily available. For example, a user satisfaction score out of 5 is difficult to measure for all users, but whether they renew their subscription after some months is measurable for everyone.
If you have already addressed the problems at the data level and are looking for more, or just want to explore solutions at the model level, here are some starter suggestions:
- come up with a self-supervised pre-training task that relies only on \(X\) and not \(Y\), so that the model learns from all of \(X\) rather than only the \(X\) that has labels,
- any domain adaptation-related technique,
- during supervised learning, importance-weigh training examples whose \(X_{\text{train}}\) is more similar to \(X_{\text{test}}\) (see the sketch after this list).
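For the importance-weighting idea, one common recipe is density-ratio estimation: train a classifier to tell training inputs from test inputs and turn its probabilities into per-example weights. Below is a rough sketch under that assumption, using scikit-learn; the function name, the clipping value, and the usage are mine:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def covariate_shift_weights(X_train, X_test, clip=10.0):
    """Estimate w(x) ~= p_test(x) / p_train(x) with a domain classifier."""
    X = np.vstack([X_train, X_test])
    domain = np.concatenate([np.zeros(len(X_train)), np.ones(len(X_test))])
    clf = LogisticRegression(max_iter=1000).fit(X, domain)

    # P(test | x) / P(train | x) for each training example, rescaled so the
    # ratio approximates the density ratio despite unequal sample sizes.
    p_test = clf.predict_proba(X_train)[:, 1]
    ratio = p_test / np.clip(1.0 - p_test, 1e-6, None)
    ratio *= len(X_train) / len(X_test)
    return np.clip(ratio, 0.0, clip)  # clip extreme weights for training stability

# Hypothetical usage: feed the weights into the supervised fit, e.g.
# weights = covariate_shift_weights(X_train, X_test)
# model.fit(X_train, y_train, sample_weight=weights)
```

The clipping is a practical safeguard: a handful of training points that look extremely test-like can otherwise dominate the loss.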
While both the data- and model-centric approaches are valid, it is always helpful to explicitly list the causes and estimate their impact, because if you realize there is a systematic yet fixable bias in data collection, no model-level approach is likely to be more fruitful than simply addressing it directly.
2. Concept Shift
A lot of real-world cases of concept shift are temporal or geographical, meaning you can observe it in datasets collected with the “same” method from different time periods or areas. For example, if my model were to guess the fed funds rate \(Y\) based on the inflation rate \(X\), it should return something very different for the 1960s vs. the 2020s. To discover whether I have a concept shift, practically speaking, I’d first check whether there is a covariate shift and try to get rid of it. If I achieved some amount of \(P(X)\) parity between training and test, then I can start attributing the gap in \(y\) to some concept shift and think about why it is happening. In my fed funds rate prediction example, I could train two models with data from different eras, give them the same input, and observe how their outputs differ. Obviously this is not theoretically foolproof, because \(f\) is not a perfect representation of \(P(Y \vert X)\), but if I wanted to quantify the concept shift, I have found a method like this to be a practical proxy measure.
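As a rough illustration of that two-era comparison (a sketch, not the exact procedure I’d ship), here is what it could look like with scikit-learn and numeric feature matrices; the model class and variable names are placeholders:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def concept_shift_gap(X_era_a, y_era_a, X_era_b, y_era_b, X_probe):
    """Fit the same model class on two eras and compare predictions on shared inputs.

    After covariate shift has been reduced, a large gap on the probe inputs
    suggests that P(Y | X) itself moved between the eras."""
    model_a = LinearRegression().fit(X_era_a, y_era_a)
    model_b = LinearRegression().fit(X_era_b, y_era_b)
    return float(np.mean(np.abs(model_a.predict(X_probe) - model_b.predict(X_probe))))

# Hypothetical usage for the fed funds example: X_* holds inflation-rate features,
# y_* the fed funds rate, split into a 1960s era and a 2020s era; X_probe is a
# shared grid of inflation values fed to both models.
```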
Notice what I left out in my model example here: the time the inflation rate is from. Often, introducing new features that can explain such a shift is the best thing to do, e.g., a cyclical pattern, a trend, or any feature with more granular information. If the cause is indeed cyclical, introducing a relevant feature will be relatively easy. However, if it’s a trend, which may not always have the same momentum and direction as it did in the past, it will be more challenging. If the shift is temporal in nature, in addition to introducing new features, you could retrain the model more frequently or do online learning. Sometimes the best thing you can do is to feed hindsight back into the model as quickly as possible.
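For the cyclical case, a common trick is to encode the periodic variable as a point on the unit circle; a minimal sketch, with hypothetical column names:

```python
import numpy as np

def cyclical_encode(values, period):
    """Map a periodic variable (e.g., month 1-12) onto the unit circle so the
    model sees December and January as neighbors."""
    angle = 2.0 * np.pi * np.asarray(values, dtype=float) / period
    return np.sin(angle), np.cos(angle)

# Hypothetical usage: add month_sin / month_cos as features instead of the raw month.
# month_sin, month_cos = cyclical_encode(df["month"], period=12)
```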
Final Thoughts
Whenever I read papers, I am amazed by the new techniques that people have tried and that worked. But in my day-to-day work of improving my own ML systems, trying out the new technique with the biggest benchmark improvement does not automatically give me the biggest bang for the buck. Rather, understanding the current problem in depth, such as with the methods I talked about above, is tremendously helpful in navigating my next set of experiments to try.
But this is not to say that reading papers about state-of-the-art (SOTA) techniques doesn’t help. On the contrary, it is immensely helpful, not necessarily because they give me a new experiment idea to try, but more often because they expose a new problem: what the current SOTA fails to address, and what “style” of approaches they take to solve it. That inspires a new way for me to parse my own problem. The internet loves to talk about LLMs and vision, where the features are relatively obvious and benchmarks are plenty. In reality, so much hinges on just the data.