Two ways to understand overfitting (and don't peek!)

Assorted comments on overfitting

The very first time I was introduced to the notion of overfitting -- by three diagrams of regression curves on scatter plots, labelled as you might guess -- I became very uncomfortable. I was probably eleven years old; I did not understand Bayesian statistics, and did not realize that seeing 51 heads out of a hundred flips doesn't imply that the coin genuinely has a heads rate of 0.51.

Or rather: I didn't realize that I did realize that. If you had made me bet on the number of heads that would come up in the next hundred flips, I would not have offered odds indicating an honest expectation of 51 heads.

(Because deep down, I had a non-uniform prior.)
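
To make that concrete, here's a minimal sketch in plain Python. The Beta(100, 100) prior is my own illustrative choice (it just encodes "most coins are very close to fair"), not anything the eleven-year-old had in mind:

```python
# Observed: 51 heads in 100 flips.
heads, tails = 51, 49

# Maximum-likelihood estimate: just the observed frequency.
mle = heads / (heads + tails)  # 0.51

# A Beta(100, 100) prior concentrated around fairness.
# (Conjugacy: Beta prior + binomial likelihood = Beta posterior.)
prior_a, prior_b = 100, 100
post_a, post_b = prior_a + heads, prior_b + tails

posterior_mean = post_a / (post_a + post_b)

print(f"MLE of heads rate:      {mle:.3f}")             # 0.510
print(f"posterior mean of rate: {posterior_mean:.3f}")  # ~0.503
```

Betting odds based on the posterior are much closer to fair-coin odds -- which is what the gut-level refusal to bet on exactly 51 heads was already tracking.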

And the same principle applies to drawing regression curves. You may insist that your curve, with all its squiggles, is "unbiased", or apply any one of the other terms (one for every squiggle) used to describe non-Bayesian estimators -- but ultimately, that's not the curve you'll bet on. You know it's far more likely that those squiggles are the result of noise, which will not be the same (or knowably correlated) in the next sample, than that they are actually a determining feature of the data.

Yes, maybe the 79th coin toss will always come up tails because of a tiny AI hidden in the coin that counts the tosses, or maybe it was the result of factors you just didn't measure. While your data should affect your beliefs, they shouldn't completely overrule your priors.

And that's the key idea behind overfitting (and really behind the Bayesian notion of probability in general) -- how will your model, built from your data, perform when exposed to data it hasn't yet seen? Because using it on data it hasn't seen is the whole point of your model -- that is your purpose in building it.

So overfitting occurs when a model learns features specific to your particular data set that don't generalize well. There are two ways this can occur:

  • The training set is a biased sample: E.g. MNIST digits are all centered (but suppose you're testing on non-centered digits), or your medical database is all from a particular country. 
  • There is noise: I mean, of course there's noise -- it's a statistical problem. Even if the world were deterministic, you still don't have all the information in the world. And you should avoid your model using this noise to make predictions, since noise is, by definition, unpredictable.

The first is the more tractable problem -- it can often be solved by data augmentation (if you have a very good picture of exactly how the data set is biased; a sketch follows below) or by transfer learning (if there's a good chance the model is picking up on important features, so you can just train it on the ones it missed). It's important to develop interesting transfer learning algorithms for this anyway, as the way humans learn often involves biased samples (e.g. personal experience) plus the reasoning capacity to unbias their knowledge.

(Not that humans always do this -- people often do form beliefs based on mere personal experience -- but humans are capable of reasoning more clearly.)
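
As an illustration of the data-augmentation route for the centered-MNIST example above, here is a rough sketch; `random_shift` and the training-set handling are my own illustrative code, not any particular library's API:

```python
import numpy as np

def random_shift(image, max_shift=3, rng=None):
    """Shift a 2-D image by a random (dy, dx) offset, padding with zeros.

    A crude stand-in for the augmentation you'd want if your training
    digits were all centered but your test digits are not.
    """
    rng = rng or np.random.default_rng()
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    h, w = image.shape
    shifted = np.zeros_like(image)
    # Copy the overlapping region of the original into its shifted position.
    shifted[max(0, dy):min(h, h + dy), max(0, dx):min(w, w + dx)] = \
        image[max(0, -dy):min(h, h - dy), max(0, -dx):min(w, w - dx)]
    return shifted

# Hypothetical usage: augment a training set of centered digits with shifted
# copies, so the model can't rely on "the digit is always in the middle".
# augmented = np.stack([random_shift(img) for img in train_images])
```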

The second is a problem that requires algorithmic solutions that nudge our neural networks towards Bayesian behaviour. Solutions like "well, parameters are just a priori unlikely to be very large, so let's penalize that" (Lasso/Ridge regression), or regularization algorithms like cross-validation and early stopping that are more complicated to explain in a simple Bayesian way.
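
Here is a minimal sketch of the first kind of solution, assuming a noisy sine curve and a high-degree polynomial fit (NumPy only; the data and the penalty strength are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# 20 noisy samples of a simple underlying curve.
x = np.linspace(0, 1, 20)
y = np.sin(2 * np.pi * x) + rng.normal(scale=0.3, size=x.shape)

# Degree-9 polynomial features: plenty of room to chase the noise.
X = np.vander(x, 10, increasing=True)

def ridge_fit(X, y, lam):
    """Minimize ||Xw - y||^2 + lam * ||w||^2 via an augmented least-squares problem."""
    n = X.shape[1]
    X_aug = np.vstack([X, np.sqrt(lam) * np.eye(n)])
    y_aug = np.concatenate([y, np.zeros(n)])
    w, *_ = np.linalg.lstsq(X_aug, y_aug, rcond=None)
    return w

w_plain = ridge_fit(X, y, lam=0.0)   # unpenalized least squares
w_ridge = ridge_fit(X, y, lam=1e-3)  # "large parameters are a priori unlikely"

# The interesting comparison is on inputs the model hasn't seen.
x_new = np.linspace(0, 1, 200)
X_new = np.vander(x_new, 10, increasing=True)
def held_out_error(w):
    return np.mean((X_new @ w - np.sin(2 * np.pi * x_new)) ** 2)

print("unpenalized held-out error:", held_out_error(w_plain))
print("ridge held-out error:      ", held_out_error(w_ridge))
```

The L2 penalty is exactly the "parameters are a priori unlikely to be very large" prior -- in MAP terms, a Gaussian prior on the coefficients.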

Another, unclassified comment on overfitting: it's very easy to accidentally "peek" at the test data. Simply in the act of saying "hey, this model works well on the test data, let's choose it", you are already running a simple algorithm that checks several models and chooses one based on its performance on the test data -- i.e. you are kinda training the model on the test data, even if this never appears explicitly in your code, only in your choice of hyperparameters.
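
Here is a sketch of the usual discipline for avoiding that kind of peeking; the commented-out model-selection step uses hypothetical names (`candidate_lams`, `fit`, `val_error`), just to show where each portion of the data is allowed to be used:

```python
import numpy as np

def train_val_test_split(X, y, val_frac=0.2, test_frac=0.2, seed=0):
    """Split the data once, up front, into train / validation / test portions."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_test = int(len(X) * test_frac)
    n_val = int(len(X) * val_frac)
    test, val, train = idx[:n_test], idx[n_test:n_test + n_val], idx[n_test + n_val:]
    return (X[train], y[train]), (X[val], y[val]), (X[test], y[test])

# Hyperparameter choice ("which model works well?") looks only at the
# validation set, e.g. (hypothetical names):
#
#   best_lam = min(candidate_lams,
#                  key=lambda lam: val_error(fit(X_train, y_train, lam), X_val, y_val))
#
# The test set is consulted exactly once, for the final chosen model.
# Choosing whichever model looks best on the test set is itself a training
# algorithm whose training data is the test set.
```

Cross-validation generalizes the validation split, but the test-set discipline is the same.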

That's why developing some kind of "theory" of hyperparameter optimization and regularization techniques is important, so that you actually have a theoretical justification for picking your models.

(Of course, this is hard. We've always been picking models, haven't we? For example, when we decide to model something as belonging to a particular family of distributions, so we only have to optimize over a 1-dimensional parameter space instead of the literal theory-space. In a sense, machine learning is a way to avoid modeling, thanks to the universal approximation theorem -- and the hope is that we can eventually make the priors as human-like as possible, which is the ultimate goal of hyperparameter optimization and regularization algorithms.)
