Estimator Predict

Use the Estimator API to make predictions with the regression model.


Chapter Goals:

  • Learn how to use an Estimator object to make regression predictions

A. Prediction

The Estimator object provides a method called predict, which is used to make model predictions.

Like the train and evaluate methods, predict takes an input function as its required argument. However, the input function for predict does not need to return any labels, since no loss is calculated when making predictions. In fact, the data observations we predict on often have no labels at all.
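
To make the difference concrete, here is a minimal plain-Python sketch (no TensorFlow) in which lists stand in for the tensors or tf.data.Dataset a real input function would return. The helpers split_columns, make_train_input_fn and make_predict_input_fn, along with the column names and values, are hypothetical. The training version returns a (features, labels) pair, while the prediction version returns the features alone and works even when no label column exists.

```python
def split_columns(columns, label_key):
    """Separate a column dict into (features, labels); labels is None if absent."""
    features = {name: values for name, values in columns.items() if name != label_key}
    labels = columns.get(label_key)
    return features, labels


def make_train_input_fn(columns, label_key):
    """Hypothetical train/evaluate input function: returns (features, labels)."""
    def input_fn():
        return split_columns(columns, label_key)
    return input_fn


def make_predict_input_fn(columns, label_key):
    """Hypothetical predict input function: returns features only."""
    def input_fn():
        features, _ = split_columns(columns, label_key)
        return features
    return input_fn


# Toy columns with made-up values: two feature columns plus a label column.
columns = {"sat": [1200.0, 1400.0], "tuition": [20000.0, 50000.0], "rate": [0.6, 0.2]}

features, labels = make_train_input_fn(columns, "rate")()
print(labels)  # [0.6, 0.2]

# Unlabeled data: there is no "rate" column at all.
unlabeled = {"sat": [1300.0], "tuition": [35000.0]}
print(make_predict_input_fn(unlabeled, "rate")())  # {'sat': [1300.0], 'tuition': [35000.0]}
```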

The predict method returns a generator of dictionaries, where each dictionary corresponds to one predicted data observation. Each dictionary contains the output values specified in the predictions argument of the EstimatorSpec that the model function returns in prediction mode.
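
Because predict returns a generator, its results are produced as you iterate, and the generator can be consumed only once. The sketch below uses a hypothetical stand-in, fake_predict, that yields dictionaries shaped like the ones described above. The 'prediction' key and the doubling rule are assumptions for illustration. itertools.islice takes just the first few results without draining the rest.

```python
import itertools


def fake_predict(inputs):
    """Hypothetical stand-in for regressor.predict: yields one dict per observation."""
    for x in inputs:
        yield {"prediction": 2.0 * x}


preds = fake_predict([1.0, 2.0, 3.0])

# Take only the first two predictions; the generator remembers where it stopped.
first_two = list(itertools.islice(preds, 2))
print(first_two)  # [{'prediction': 2.0}, {'prediction': 4.0}]

# Iterating again continues from the third observation.
rest = [p["prediction"] for p in preds]
print(rest)  # [6.0]

# The generator is now exhausted, so another pass yields nothing.
print(list(preds))  # []
```

The same iteration pattern applies to the object returned by regressor.predict(input_fn). If you need a second pass over the predictions, call predict again or store the results in a list.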

preds = regressor.predict(input_fn)

By default, each dictionary contains all the values from the predictions argument of the EstimatorSpec. To return only specific values, set the predict_keys keyword argument of predict. This argument takes a list of strings naming the values we want returned in each dictionary.
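
One way to picture the effect of predict_keys is as a key filter applied to each dictionary. The function below, select_predict_keys, is a hypothetical illustration of that effect, not the Estimator's own code. The extra 'raw_output' key is an assumed second entry in the predictions dictionary. Passing None keeps every key, mirroring the default, while a list keeps only the named keys.

```python
def select_predict_keys(pred_dict, predict_keys=None):
    """Hypothetical illustration of predict_keys: None keeps all keys."""
    if predict_keys is None:
        return dict(pred_dict)
    # Keep only the requested keys that are present in the dictionary.
    return {key: value for key, value in pred_dict.items() if key in predict_keys}


# Assumed predictions dictionary with two entries.
full = {"prediction": 0.25, "raw_output": [0.25]}

print(select_predict_keys(full))                  # {'prediction': 0.25, 'raw_output': [0.25]}
print(select_predict_keys(full, ["prediction"]))  # {'prediction': 0.25}
```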

preds = regressor.predict(
    input_fn,
    predict_keys=['prediction'])

Below is a plot of regression predictions made with an Estimator (run the show_plot() function to display it). The points represent the 2015 admission rates of 50 randomly chosen colleges.

The plot’s z-axis corresponds to admission rate, the y-axis corresponds to average SAT scores from 2013-2015, and the x-axis corresponds to average tuition from 2013-2016.

The regression model was trained on a variety of factors, of which SAT scores and college tuition were the most informative. Its task was to predict a school’s 2015 admission rate.

Points representing actual admission rates are marked in blue, while model predictions are marked in green.

Hovering over a point shows the name of the school along with either its actual or its predicted admission rate. You can also use the grey icon bar at the top right of the plot to adjust the camera settings.

show_plot()
