Sep 20, 2021 · post

How (and when) to enable early stopping for Gensim's Word2Vec

The Gensim library is a staple of the NLP stack. While it primarily focuses on topic modeling and similarity for documents, it also supports several word embedding algorithms, including what is likely the best-known implementation of Word2Vec.

Word embedding models like Word2Vec use unlabeled data to learn vector representations for each token in a corpus. These embeddings can then be used as features in myriad downstream tasks such as classification, clustering, or recommendation systems. Gensim’s documentation and tutorials are sufficient for straightforward applications, e.g., training on a corpus of documents composed of plain text. But they don’t cover what to do if your use case is more complex, like how to choose the number of training epochs or other hyperparameter values. This is especially crucial if you’re trying to learn vector representations for non-natural language tokens, e.g., learning embeddings for products or users or books or music. And what does early stopping have to do with any of this?

In this post, we’ll first cover why and when you should and shouldn’t do early stopping with Gensim’s Word2Vec, and we’ll finish up with how to do it with code. Because we’re no fun, let’s start with the “don’t” before getting to the “do”s.

Don’t use early stopping to prevent overfitting

What do we mean when we say “early stopping”? Early stopping was developed as a regularization technique to prevent model overfitting. The idea is to monitor the prediction error on both the training set and the validation set during training. As training progresses, both errors initially decrease, but eventually the validation error starts to climb while the training error keeps falling. This is a classic sign of overfitting, and one solution is early stopping: terminating training at the point when the validation error reaches its minimum.
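
To make that recipe concrete, here’s a minimal sketch of the classic patience-based loop; train_one_epoch and validation_loss are hypothetical stand-ins for your own training and evaluation routines:

import math

def early_stopping_loop(model, max_epochs=100, patience=5):
    best_loss, bad_epochs = math.inf, 0
    for epoch in range(max_epochs):
        train_one_epoch(model)             # hypothetical: one pass over the training set
        val_loss = validation_loss(model)  # hypothetical: error on the held-out set
        if val_loss < best_loss:
            best_loss, bad_epochs = val_loss, 0  # validation error still falling
        else:
            bad_epochs += 1                      # validation error rising: possible overfitting
        if bad_epochs >= patience:
            break                                # stop near the validation minimum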

[Figure: training and validation error curves over training time; the validation error reaches a minimum and then rises while the training error keeps falling. Image credit: paperswithcode]

While this is a perfectly valid technique for many kinds of models, Word2Vec isn’t one of them, for a couple of reasons:

  1. Training loss in Word2Vec doesn’t mean a whole lot. It’s buggy and misleading; even the Gensim maintainers know this! It’s been on their To-Do list of open Issues for years (and we don’t begrudge them that — it’s a small core team doing good work!).
  2. Even once that issue is resolved, the final training loss has little to do with how well the resulting embeddings will perform on downstream tasks. For example, let’s suppose we trained a Word2Vec model on a small dataset and noted its final training loss. We then trained a second model using the exact same data, but this time we increased the size of the embedding vectors. Model 2 now has more dimensions to “learn.” However, since our dataset was very small, this larger model would simply resort to memorization. This would yield a lower training loss, but would ultimately result in worse embeddings that aren’t able to generalize to unseen data!

The moral of the story is that early stopping for Word2Vec should not be done to prevent overfitting, which is difficult, if not impossible, to gauge from the training loss alone. If overfitting is your concern, assess the embeddings against the end application, e.g., the classification, clustering, or recommendation task you ultimately want to use them for. Experimenting with various vector sizes for a fixed amount of data will do more to prevent overfitting than early stopping.
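
For instance, here’s a sketch of that kind of sweep, where train is assumed to be your corpus and evaluate_downstream is a hypothetical function that scores the embeddings on your end task (say, Recall@10):

from gensim.models.word2vec import Word2Vec

scores = {}
for vector_size in (16, 32, 64, 128):
    model = Word2Vec(sentences=train, vector_size=vector_size, seed=42)
    scores[vector_size] = evaluate_downstream(model.wv)  # hypothetical downstream metric

best_size = max(scores, key=scores.get)  # the size that generalizes best on the end task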

With that out of the way, let’s look at times when early stopping Word2Vec does make sense.

Do use early stopping to save compute

Instead of preventing overfitting, there are times when we simply wish to interrupt Word2Vec’s training loop before it completes the specified number of epochs in order to conserve computational resources. We thought of two situations where this could come in handy: during hyperparameter optimization, and as a way to train for a sufficiently large number of epochs without having to include epochs as a hyperparameter.

Hyperparameter tuning

Word2Vec has a host of hyperparameters, from the embedding vector size to the learning rate, to more esoteric quantities like the context window size and negative sampling exponent. The defaults stored in Gensim’s Word2Vec implementation come directly from the academic literature in which the authors empirically determined hyperparameter values that work well for most natural language tasks.
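
For reference, here are a few of those defaults as of Gensim 4.x; this call is equivalent to simply Word2Vec(sentences=corpus), where corpus is assumed to be an iterable of token lists:

from gensim.models.word2vec import Word2Vec

model = Word2Vec(
    sentences=corpus,
    vector_size=100,   # embedding dimensionality
    window=5,          # context window size
    negative=5,        # number of negative samples per positive example
    ns_exponent=0.75,  # exponent shaping the negative-sampling distribution
    alpha=0.025,       # initial learning rate
    min_alpha=0.0001,  # final learning rate after linear decay
    sg=0,              # CBOW by default; sg=1 selects skip-gram
    epochs=5,          # number of passes over the corpus
)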

But what if you aren’t learning embeddings for word tokens? In the paper Word2Vec applied to Recommendations: Hyperparameters Matter, the authors use Word2Vec to learn embeddings for all kinds of entities: songs in a music queue, articles browsed on a news website, and products purchased on an e-commerce store. They demonstrate that tuning Word2Vec is crucial for achieving useful embeddings.

But hyperparameter optimization is expensive, especially on very large corpora. And Gensim’s implementation, while fast, is not implemented for GPUs, so you’ll want to make the most of the CPUs at your disposal. One way to do this is to identify poorly-performing hyperparameter configurations during the optimization phase and terminate them early.

Here we show the results of early stopping during hyperparameter tuning for Word2Vec trained on product IDs for an e-commerce use case. Each colored curve represents a different hyperparameter configuration. We evaluate the embeddings we’ve learned up to that epoch on a metric appropriate for our downstream task (Recall@10 for providing recommendations). Using sophisticated scheduling algorithms provided by the Ray Tune library (more on this below), we can automatically terminate underperforming configurations in a principled fashion, thus saving loads of compute time: of the fifteen attempted configurations, only six were trained for a full fifty epochs.

Avoid tuning the number of epochs

In the figure above we saw that the top hyperparameter configurations were each trained for fifty epochs. But how do you know how many epochs to train for?

This may seem like a simple question but it quickly becomes trickier when you consider that Word2Vec was designed to train with a learning rate that decays over the course of training. This is all worked out for you behind the scenes in the implementation details, but it means that you absolutely should not try to design your own training loop to control the number of epochs. (Related: see this StackOverflow post in which one of the Gensim maintainers illustrates why you shouldn’t do this.)
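
To see why, note that Gensim linearly decays the learning rate from alpha down to min_alpha across all of the epochs you request in a single call. A hand-rolled loop that calls train() one epoch at a time restarts that schedule on every call unless you carefully manage the rates yourself. A sketch, assuming corpus is your training iterable:

from gensim.models.word2vec import Word2Vec

# Do: one call, and Gensim decays alpha -> min_alpha smoothly over all 50 epochs
model = Word2Vec(sentences=corpus, alpha=0.025, min_alpha=0.0001, epochs=50)

# Don't: a manual loop like this mismanages the decay schedule
# model = Word2Vec(sentences=corpus, epochs=1)
# for _ in range(49):
#     model.train(corpus, total_examples=model.corpus_count, epochs=1)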

One option is to include epochs as a tunable hyperparameter during optimization, but that can quickly explode the number of configurations to validate. Thankfully, there’s another way.

We need to provide the model with a guesstimate of the number of epochs: too few and the model will underfit; too many and nothing really bad will happen, since the embeddings will converge and additional training won’t have much effect. It’s therefore safer to err on the side of more epochs, but that choice carries computational costs: training for many more epochs than necessary is wasteful and time-consuming. It would be better to train for exactly as long as you need. But since you don’t know that a priori, this is a case where you might specify a very large number of epochs and then consider early stopping.

How do we know when to stop? Our goal is quality embeddings, so we can monitor those embeddings on a downstream task and stop training once the metric plateaus. Refer again to the figure above. Only seven hyperparameter configurations made it past the first round of early stopping, and most of those had Recall@10 scores that plateaued after about thirty epochs of training. There may still be some upward trend in the blue, purple, and pink runs, but the orange, red, and grey runs look quite flat. It’s likely that many of these runs have already produced embeddings about as good as they’re going to get. We could thus have saved additional compute time by terminating training runs in which our desired metric had plateaued.

Of course, the drawback with this approach is that it only works if you have a downstream task in mind and appropriate metrics that can evaluate embeddings against that task.

Time for code!

As of Gensim 4.0, *2Vec models do not have an early stopping feature. While there has been discussion of including such functionality in the future (see this Issue), it’s not currently on the road map. We’ll need help from another library: Ray Tune. Ray Tune is a Python library for experiment execution and hyperparameter tuning at scale. It also supports state-of-the-art scheduling algorithms for efficiently handling hyperparameter optimization. (Interested in HPO? Check out our deep dive on the subject here!)

Callbacks

In order for Ray Tune to monitor our Word2Vec model and perform early stopping for us, we need to provide it with information on how the model is performing during training. For this, we’ll take advantage of Gensim callbacks. The examples in the Gensim documentation show how to report information about the model or the training itself (the current epoch, for example). But we need something more sophisticated: we need to monitor how the embeddings we’ve learned so far perform on a downstream task. This means we need access to the model’s KeyedVectors (its wv attribute) during training. This is actually a bit tricky, because we need to simulate training completion, at which point the embeddings are normalized for downstream use. To do this, we make a deepcopy of the model so that the original model can continue training in the next epoch.

import copy

from gensim.models.callbacks import CallbackAny2Vec
from ray import tune

class RecallAtKLogger(CallbackAny2Vec):
    '''Report Recall@K metric at the end of each epoch

    Computes and reports Recall@K on a validation set with
    a given value of k (number of recommendations to generate).
    '''
    def __init__(self, validation, k=10, ray_tune=False):
        self.epoch = 0
        self.validation = validation
        self.k = k
        self.ray_tune = ray_tune

    def on_epoch_end(self, model):
        # make a deepcopy of the model and emulate training completion
        mod = copy.deepcopy(model)
        mod._clear_post_train()

        # compute the metric we care about on a recommendation task
        # with the validation set, using the model's embedding vectors
        score = 0
        for query_item, ground_truth in self.validation:
            try:
                # get the k most similar items to the query item
                # (a KeyError means the query item is out of vocabulary)
                neighbors = mod.wv.similar_by_vector(mod.wv[query_item], topn=self.k)
            except KeyError:
                pass
            else:
                recommendations = [item for item, distance in neighbors]
                if ground_truth in recommendations:
                    score += 1
        score /= len(self.validation)

        if self.ray_tune:
            tune.report(recall_at_k=score)
        else:
            print(f"Epoch {self.epoch} -- Recall@{self.k}: {score}")
        self.epoch += 1
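
Outside of Ray Tune, the callback can also be used on its own just to print the metric after every epoch. A minimal sketch, assuming train holds your training corpus and valid your (query, ground truth) validation pairs:

from gensim.models.word2vec import Word2Vec

ratk_logger = RecallAtKLogger(valid, k=10)  # ray_tune=False: just print the score each epoch
model = Word2Vec(sentences=train, vector_size=48, epochs=50, callbacks=[ratk_logger])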

Early Stopping with Ray Tune

Once we have our Gensim Callback, early stopping with Ray Tune is easy!

from gensim.models.word2vec import Word2Vec
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.stopper import TrialPlateauStopper

# this helper function is what Tune uses to optimize hyperparameters
# (it assumes `train` and `valid` hold our training corpus and validation pairs)
def tune_w2v(hyperparameters: dict):
    """Hyperparameter optimization wrapper for Ray Tune"""
    # instantiate our callback logger
    ratk_logger = RecallAtKLogger(valid, k=10, ray_tune=True)
    # instantiate our model, passing our callback and hyperparameters
    model = Word2Vec(sentences=train, callbacks=[ratk_logger], **hyperparameters)

# define the search space of hyperparameters (we're performing a random search)
# we conclude the search space with some fixed hyperparameters that we want
# passed to the model every time, including the number of epochs ("epochs")
search_space = {
    "vector_size":  tune.randint(10, 100),
    "window":       tune.randint(3, 25),
    "ns_exponent":  tune.quniform(-1.0, 1.0, 0.2),
    "alpha":        tune.loguniform(1e-4, 1e-2),
    "negative":     tune.randint(1, 25),
    "epochs": 50,   # number of training epochs
    "min_count": 1, # minimum number of times a token must appear before word2vec creates an embedding vector for it
    "workers": 6,   # number of CPU workers
    "sg": 1,        # trains the skip-gram version of word2vec
}

# ASHA terminates under-performing hyperparameter configurations in a principled way
asha_scheduler = ASHAScheduler(max_t=100, grace_period=10) 

# terminates training sessions in which the monitored metric has plateaued
stopping_criterion = TrialPlateauStopper(metric='recall_at_k', std=0.002) 

# Let Ray Tune optimize!
analysis = tune.run(
    tune_w2v,
    metric="recall_at_k",
    mode="max",
    scheduler=asha_scheduler,
    stop=stopping_criterion,
    verbose=1,
    num_samples=15,       # number of trials to attempt
    config=search_space
)
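
When the run completes, tune.run returns an ExperimentAnalysis object, from which we can read off the winning hyperparameter configuration (and, if we like, a dataframe of per-trial results):

# hyperparameters of the trial that achieved the best recall_at_k
print(analysis.best_config)

# per-trial results as a pandas DataFrame, for further inspection
df = analysis.dataframe()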

Look at the results now! We’ve plotted them on the same scale as the figure above, and the difference is striking. Not a single training run lasted the full fifty epochs! Some were stopped by the ASHA scheduler due to suboptimal hyperparameter values; others were stopped because the downstream metric we were monitoring had plateaued. And if we examine the best-performing model from each experiment, both yield embeddings with roughly the same predictive performance on our downstream recommendation task, indicating that we didn’t have to sacrifice quality for speed-up.

There are, of course, many caveats, mostly due to randomness in these experiments. The hyperparameter optimization we perform uses a random search so the resulting models are certainly not identical. Even if we fixed the hyperparameter values (say, via a grid search), there would still be minor differences since Word2Vec is randomly initialized before training.

These methods also rely on a crucial requirement: that you have a downstream task in mind and have chosen a metric suitable for measuring its success. We demonstrated that through a simple recommendations application: we learned product embeddings and then evaluated how successful they were against a set of ground truth user data.

Conclusion

We did it! We performed two kinds of early stopping on Gensim’s Word2Vec in a principled way using the algorithms provided by the Ray Tune hyperparameter optimization library. What embeddings are you going to train next?
