
Sep 20, 2021 · post

How (and when) to enable early stopping for Gensim's Word2Vec

The Gensim library is a staple of the NLP stack. While it primarily focuses on topic modeling and similarity for documents, it also supports several word embedding algorithms, including what is likely the best-known implementation of Word2Vec.

Word embedding models like Word2Vec use unlabeled data to learn vector representations for each token in a corpus. These embeddings can then be used as features in myriad downstream tasks such as classification, clustering, or recommendation systems. Gensim’s documentation and tutorials are sufficient for straightforward applications, e.g., training on a corpus of plain-text documents. But they don’t cover more complex use cases, such as how to choose the number of training epochs or other hyperparameter values. This is especially crucial if you’re learning vector representations for tokens that aren’t natural language, e.g., embeddings for products, users, books, or music. And what does early stopping have to do with any of this?

In this post, we’ll first cover when you should and shouldn’t use early stopping with Gensim’s Word2Vec, and why, and we’ll finish up with how to do it in code. Because we’re no fun, let’s start with the “don’t” before getting to the “do”s.

Don’t use early stopping to prevent overfitting

What do we mean when we say “early stopping”? Early stopping was developed as a regularization technique to prevent model overfitting. One monitors the prediction error on both the training set and the validation set during training. As training progresses, both errors decrease, but eventually the validation error starts to increase. This is a classic sign of overfitting, and one solution is early stopping: terminating training at the point where the validation error reaches its minimum.

Image credit: paperswithcode
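
To make the classic recipe concrete, here is a minimal sketch of patience-based early stopping in plain Python. The function name, the `patience` parameter, and the validation errors are all made up for illustration; real frameworks differ in the details.

```python
def early_stopping_epoch(val_errors, patience=3):
    """Return the epoch with the lowest validation error seen before
    the error fails to improve for `patience` consecutive epochs."""
    best_epoch = 0
    best_error = float("inf")
    for epoch, error in enumerate(val_errors):
        if error < best_error:
            best_epoch, best_error = epoch, error
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` epochs: stop training
    return best_epoch


# hypothetical validation errors: falling at first, then rising
# as the model starts to overfit
val_errors = [0.90, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.67]
stop_at = early_stopping_epoch(val_errors, patience=3)
print(f"Keep the model from epoch {stop_at}")
```

Note that this only works because the validation error is a trustworthy signal of generalization, which is exactly what we can’t assume for Word2Vec’s training loss.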

While this is a perfectly valid technique for many kinds of models, it isn’t a good fit for Word2Vec, for a couple of reasons:

  1. Training loss in Word2Vec doesn’t mean a whole lot. It’s buggy and misleading; even the Gensim maintainers know this! It’s been on their To-Do list of open Issues for years (and we don’t begrudge them that — it’s a small core team doing good work!).
  2. Even once that issue is resolved, the final training loss has little to do with how well the resulting embeddings will perform on downstream tasks. For example, let’s suppose we trained a Word2Vec model on a small dataset and noted its final training loss. We then trained a second model using the exact same data, but this time we increased the size of the embedding vectors. Model 2 now has more dimensions to “learn.” However, since our dataset was very small, this larger model would simply resort to memorization. This would yield a lower training loss, but would ultimately result in worse embeddings that aren’t able to generalize to unseen data!

The moral of the story is that early stopping for Word2Vec should not be done to prevent overfitting. Overfitting is difficult, if not impossible, to gauge from the training loss alone. If overfitting is your concern, assess the embeddings against the end application, e.g., the classification, clustering, or recommendation task you ultimately want to use them for. Experimenting with various vector sizes for a fixed amount of data will do more to prevent overfitting than early stopping.

With that out of the way, let’s look at times when early stopping Word2Vec does make sense.

Do use early stopping to save compute

Instead of preventing overfitting, there are times when we simply wish to interrupt Word2Vec’s training loop before it completes the specified number of epochs in order to conserve computational resources. We thought of two situations where this could come in handy: during hyperparameter optimization, and as a way to train for a sufficiently large number of epochs without having to include epochs as a hyperparameter.

Hyperparameter tuning

Word2Vec has a host of hyperparameters, from the embedding vector size to the learning rate, to more esoteric quantities like the context window size and negative sampling exponent. The defaults stored in Gensim’s Word2Vec implementation come directly from the academic literature in which the authors empirically determined hyperparameter values that work well for most natural language tasks.

But what if you aren’t learning embeddings for word tokens? In the paper Word2Vec applied to Recommendations: Hyperparameters Matter, the authors use Word2Vec to learn embeddings for all kinds of entities: songs in a music queue, articles browsed on a news website, and products purchased on an e-commerce store. They demonstrate that tuning Word2Vec is crucial for achieving useful embeddings.

But hyperparameter optimization is expensive, especially on very large corpora. And Gensim’s implementation, while fast, is not implemented for GPUs, so you’ll want to make the most of the CPUs at your disposal. One way to do this is to identify poorly-performing hyperparameter configurations during the optimization phase and terminate them early.
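
One common way to terminate weak configurations early is successive halving: train everything briefly, keep the best fraction, give the survivors a bigger budget, and repeat. The sketch below is a simplified, synchronous version in plain Python, not Ray Tune’s implementation. The `toy_score` function is a hypothetical stand-in for “train for this many epochs and evaluate”, and all names here are ours.

```python
def successive_halving(configs, evaluate, min_epochs=1, max_epochs=8, keep_fraction=0.5):
    """Train all configs briefly, keep the best fraction, double the budget, repeat."""
    survivors = list(configs)
    budget = min_epochs
    history = []  # (epoch budget, number of configs still training)
    while True:
        history.append((budget, len(survivors)))
        if budget >= max_epochs or len(survivors) == 1:
            break
        # rank by the metric reached after `budget` epochs (higher is better)
        ranked = sorted(survivors, key=lambda c: evaluate(c, budget), reverse=True)
        survivors = ranked[:max(1, int(len(ranked) * keep_fraction))]
        budget = min(budget * 2, max_epochs)
    return survivors, history


def toy_score(config, epochs):
    """Hypothetical metric: each config approaches its own ceiling as it trains."""
    return config["ceiling"] * (1 - 0.5 ** epochs)


configs = [{"name": f"cfg{i}", "ceiling": 0.1 * (i + 1)} for i in range(8)]
survivors, history = successive_halving(configs, toy_score)
print(survivors[0]["name"], history)
```

In this toy run only one of the eight configurations ever receives the full budget; the rest are cut after one, two, or four epochs.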

Here we show the results of early stopping during hyperparameter tuning for Word2Vec trained on product IDs for an e-commerce use case. Each colored curve represents a different hyperparameter configuration. We evaluate the embeddings we’ve learned up to that epoch on a metric appropriate for our downstream task (Recall@10 for providing recommendations). Using sophisticated scheduling algorithms provided by the Ray Tune library (more on this below), we can automatically terminate underperforming configurations in a principled fashion, thus saving loads of compute time: of the fifteen attempted configurations, only six were trained for a full fifty epochs.
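
For readers unfamiliar with the metric, here is a minimal sketch of Recall@K as we use it here: the fraction of queries whose single ground-truth item shows up among the top K recommendations. The dictionaries and product IDs below are invented for illustration.

```python
def recall_at_k(recommendations, ground_truth, k=10):
    """Fraction of queries whose ground-truth item is in the top-k recommendations.

    recommendations: dict mapping a query item to a ranked list of items
    ground_truth:    dict mapping a query item to the one item we hope to see
    """
    if not ground_truth:
        return 0.0
    hits = sum(
        1
        for query, truth in ground_truth.items()
        if truth in recommendations.get(query, [])[:k]
    )
    return hits / len(ground_truth)


# hypothetical ranked recommendations for three query products
recs = {
    "p1": ["p7", "p3", "p9"],
    "p2": ["p4", "p8", "p1"],
    "p5": ["p2", "p6", "p3"],
}
truth = {"p1": "p3", "p2": "p9", "p5": "p3"}

score_at_3 = recall_at_k(recs, truth, k=3)  # p1 and p5 are hits, p2 is a miss
score_at_2 = recall_at_k(recs, truth, k=2)  # only p1 is a hit
```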

Avoid tuning the number of epochs

In the figure above we saw that the top hyperparameter configurations were each trained for fifty epochs. But how do you know how many epochs to train for?

This may seem like a simple question but it quickly becomes trickier when you consider that Word2Vec was designed to train with a learning rate that decays over the course of training. This is all worked out for you behind the scenes in the implementation details, but it means that you absolutely should not try to design your own training loop to control the number of epochs. (Related: see this StackOverflow post in which one of the Gensim maintainers illustrates why you shouldn’t do this.)
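
A small sketch shows why a hand-rolled loop is risky. It assumes a linear decay between Gensim’s default `alpha` (0.025) and `min_alpha` (0.0001), and it assumes the naive loop calls training once per epoch without adjusting the learning rate itself. The function `linear_alpha` is ours, not part of Gensim.

```python
def linear_alpha(progress, alpha=0.025, min_alpha=0.0001):
    """Learning rate after a fraction `progress` (0 to 1) of a training call."""
    return alpha - (alpha - min_alpha) * progress


epochs = 5

# one training call covering all epochs: the rate at the start of each
# epoch falls steadily across the whole run
single_run = [linear_alpha(e / epochs) for e in range(epochs)]

# naive loop with one training call per epoch: every call restarts the
# schedule, so each epoch begins at the full learning rate again
manual_loop = [linear_alpha(0.0) for _ in range(epochs)]

print([round(a, 5) for a in single_run])
print([round(a, 5) for a in manual_loop])
```

In the second case each epoch sweeps the entire range from `alpha` down to `min_alpha` and then jumps back up, which is not the schedule Word2Vec was designed around.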

One option is to include epochs as a tunable hyperparameter during optimization, but that can quickly explode the number of configurations to validate. Thankfully, there’s another way.

We need to provide the model with a guesstimate of the number of epochs. Too few and the model will underfit; too many and nothing really bad will happen: the model’s embeddings will converge and additional training won’t have much effect. Therefore, it’s safer to err on the side of more epochs. However, there are computational costs associated with this choice. Training for many more epochs than is necessary to achieve good results is wasteful and time-consuming. It would be better to train for exactly as long as you need. But since you don’t know that a priori, this is a time when you might set a very large number of epochs and then rely on early stopping.

How do we know when to stop? Our goal is to achieve quality embeddings so we can monitor those embeddings on a downstream task and stop training once those metrics plateau. Refer again to the figure above. Only seven hyperparameter configurations moved past the first round of early stopping, and most of those had Recall@10 scores that plateaued after about thirty epochs of training. There may still be some additional upward trend in the blue, purple, and pink runs, but the orange, red and grey look quite flat. It’s likely that many of these runs have trained such that the embeddings are about as good as they’re going to get. We thus could have saved additional compute time by terminating training runs in which our desired metric has plateaued.
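
One simple way to define “plateaued” is to check whether the metric’s standard deviation over the last few epochs has dropped below a threshold. The sketch below is a plain-Python illustration of that idea; the function, its parameter names, and the scores are hypothetical, and it is not the code of any particular library.

```python
import statistics


def has_plateaued(scores, window=4, std_threshold=0.002, grace_period=4):
    """True once the last `window` scores vary by less than `std_threshold`."""
    if len(scores) < max(window, grace_period):
        return False  # too early to judge
    return statistics.pstdev(scores[-window:]) < std_threshold


# hypothetical Recall@10 after each epoch for two training runs
still_rising = [0.10, 0.15, 0.20, 0.25, 0.30]
flattening = [0.10, 0.20, 0.28, 0.300, 0.301, 0.300, 0.301]

# number of epochs completed when the flattening run would be stopped
stop_epoch = next(
    n for n in range(1, len(flattening) + 1) if has_plateaued(flattening[:n])
)
print(has_plateaued(still_rising), stop_epoch)
```

The threshold and window size are judgment calls: too loose and you stop runs that are still improving, too strict and you save little compute.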

Of course, the drawback with this approach is that it only works if you have a downstream task in mind and appropriate metrics that can evaluate embeddings against that task.

Time for code!

As of Gensim 4.0, *2Vec models do not have an early stopping feature. While there has been discussion of including such functionality in the future (see this Issue), it’s not currently on the road map. We’ll need help from another library: Ray Tune. Ray Tune is a Python library for experiment execution and hyperparameter tuning at scale. It also supports state-of-the-art scheduling algorithms for efficiently handling hyperparameter optimization. (Interested in HPO? Check out our deep dive on the subject here!)

Callbacks

In order for Ray Tune to monitor our Word2Vec model and perform early stopping for us, we need to provide it with information on how the model is performing during training. For this, we’ll take advantage of Gensim callbacks. The examples in the Gensim documentation show how to report information about the model or the training itself (the current epoch, for example). But we need something more sophisticated: we need to monitor how the embeddings we’ve learned so far are performing on a downstream task. This means we need to access the model’s KeyedVectors (its wv attribute) during training. This is actually a bit tricky because we need to simulate training completion, during which the embeddings are normalized for downstream use. To do this, we make a deepcopy of the model and evaluate the copy, so that the original model can continue training in the next epoch.

import copy
from gensim.models.callbacks import CallbackAny2Vec
from ray import tune  # needed for the tune.report call in the callback

class RecallAtKLogger(CallbackAny2Vec):
    '''Report Recall@K metric at end of each epoch
    
    Computes and reports Recall@K on a validation set with
    a given value of k (number of recommendations to generate). 
    '''
    def __init__(self, validation, k=10, ray=False):
        self.epoch = 0
        self.validation = validation
        self.k = k
        self.ray = ray

    def on_epoch_end(self, model):
        # make a deepcopy of the model and emulate training completion,
        # so the original model can keep training in the next epoch
        mod = copy.deepcopy(model)
        mod._clear_post_train()

        # compute the metric we care about on a recommendation task
        # with the validation set using the model's embedding vectors
        score = 0
        for query_item, ground_truth in self.validation:
            try:
                # get the k most similar items to the query item
                neighbors = mod.wv.similar_by_vector(query_item, topn=self.k)
            except KeyError:
                # query item is not in the model's vocabulary; count it as a miss
                pass
            else:
                recommendations = [item for item, similarity in neighbors]
                if ground_truth in recommendations:
                    score += 1
        score /= len(self.validation)

        if self.ray:
            # report the metric to Ray Tune so schedulers and stoppers can act on it
            tune.report(recall_at_k=score)
        else:
            print(f"Epoch {self.epoch} -- Recall@{self.k}: {score}")
        self.epoch += 1
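
The copy-then-evaluate pattern in the callback can be illustrated without Gensim. The toy sketch below normalizes a deep copy of some made-up vectors and queries the copy, leaving the “training” vectors untouched. It does not reproduce Gensim’s internals; the helper functions and data are hypothetical.

```python
import copy
import math


def normalize(vectors):
    """Scale every vector to unit length, in place."""
    for key, vec in vectors.items():
        norm = math.sqrt(sum(x * x for x in vec))
        vectors[key] = [x / norm for x in vec]


def top_k(vectors, query, k):
    """Keys of the k vectors most similar to `query` (dot product of unit vectors)."""
    q = vectors[query]
    sims = {
        key: sum(a * b for a, b in zip(q, vec))
        for key, vec in vectors.items()
        if key != query
    }
    return sorted(sims, key=sims.get, reverse=True)[:k]


# made-up embeddings that are still "mid-training"
training_vectors = {"a": [3.0, 4.0], "b": [6.0, 8.5], "c": [-1.0, 0.5]}

snapshot = copy.deepcopy(training_vectors)  # evaluate on a copy
normalize(snapshot)
neighbors = top_k(snapshot, "a", k=1)

print(neighbors, training_vectors["a"])  # the original vectors are unchanged
```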

Early Stopping with Ray Tune

Once we have our Gensim Callback, early stopping with Ray Tune is easy!

from gensim.models.word2vec import Word2Vec
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.stopper import TrialPlateauStopper

# this helper function is what Tune uses to optimize hyperparameters
# (`train` and `valid` are assumed to be defined already: the training
# sequences and the validation pairs consumed by the callback)
def tune_w2v(hyperparameters: dict):
    """Hyperparameter optimization wrapper for Ray Tune"""
    # instantiate our callback logger
    ratk_logger = RecallAtKLogger(valid, k=10, ray=True)
    # instantiate our model and pass our callback and hyperparameters
    model = Word2Vec(sentences=train, callbacks=[ratk_logger], **hyperparameters)

# define the search space of hyperparameters (we're performing a random search)
# we conclude the search space with some fixed hyperparameters that we want
# passed to the model every time, including the number of epochs ("iter")
# note: Gensim 4.x renamed "size" to "vector_size" and "iter" to "epochs"
search_space = {
    "size":         tune.randint(10, 100),
    "window":       tune.randint(3, 25),
    "ns_exponent":  tune.quniform(-1.0, 1.0, .2),
    "alpha":        tune.loguniform(1e-4, 1e-2),
    "negative":     tune.randint(1, 25),
    "iter": 50,     # number of epochs
    "min_count": 1, # number of instances a token appears before word2vec creates an embedding vector for it
    "workers": 6,   # number of CPU workers
    "sg": 1,        # trains the skip-gram version of word2vec
}

# ASHA terminates under-performing hyperparameter configurations in a principled way
asha_scheduler = ASHAScheduler(max_t=100, grace_period=10) 

# terminates training sessions in which the monitored metric has plateaued
stopping_criterion = TrialPlateauStopper(metric='recall_at_k', std=0.002) 

# Let Ray Tune optimize!
analysis = tune.run(
    tune_w2v,
    metric="recall_at_k",
    mode="max",
    scheduler=asha_scheduler,
    stop=stopping_criterion,
    verbose=1,
    num_samples=15,       # number of trials to attempt
    config=search_space
)

Look at the results now! We’ve plotted them on the same scale as the figure above and the difference is striking. Not a single training run ran for the full 50 epochs! Some were stopped by the ASHA Scheduler due to suboptimal hyperparameter values; others were stopped due to a plateau in the downstream metric we were monitoring. And if we examine the best performing model in both experiments, each yields embeddings with roughly the same amount of predictive performance for our downstream recommendation task, indicating that we didn’t have to sacrifice quality for speed-up.

There are, of course, many caveats, mostly due to randomness in these experiments. The hyperparameter optimization we perform uses a random search so the resulting models are certainly not identical. Even if we fixed the hyperparameter values (say, via a grid search), there would still be minor differences since Word2Vec is randomly initialized before training.
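
To picture what a random search draws, here is a plain-Python sketch that samples configurations from ranges mirroring the search space used earlier. It is only an approximation of Ray Tune’s sampling: we assume integer upper bounds are exclusive and that the quantized range snaps to multiples of 0.2. The `sample_config` function is ours.

```python
import math
import random


def sample_config(rng):
    """Draw one hypothetical hyperparameter configuration."""
    return {
        "size": rng.randrange(10, 100),
        "window": rng.randrange(3, 25),
        # uniform in [-1, 1], snapped to the nearest multiple of 0.2
        "ns_exponent": round(round(rng.uniform(-1.0, 1.0) / 0.2) * 0.2, 1),
        # log-uniform: uniform in log space, so small rates are well covered
        "alpha": math.exp(rng.uniform(math.log(1e-4), math.log(1e-2))),
        "negative": rng.randrange(1, 25),
    }


rng = random.Random(0)  # a fixed seed makes the sampled configs repeatable
trials = [sample_config(rng) for _ in range(15)]

# re-seeding reproduces exactly the same fifteen configurations
repeat = [sample_config(random.Random(0)) for _ in range(1)]
```

Fixing the sampler’s seed pins down which configurations are tried, but as noted above it does not remove the randomness inside Word2Vec training itself.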

These methods also rely on a crucial requirement: that you have a downstream task in mind and have chosen a metric suitable for measuring its success. We demonstrated that through a simple recommendations application: we learned product embeddings and then evaluated how successful they were against a set of ground truth user data.

Conclusion

We did it! We performed two kinds of early stopping on Gensim’s Word2Vec in a principled way using the algorithms provided by the Ray Tune hyperparameter optimization library. What embeddings are you going to train next?

by Melanie Beck