Estimator Predict
Use the Estimator API to make predictions with the regression model.
Chapter Goals:
- Learn how to use an Estimator object to make regression predictions
A. Prediction
The Estimator object provides a function called predict, which is used to make model predictions.
Like the train and evaluate functions, predict takes an input data function as its required argument. However, the input data function for predict does not need to return any labels, since no loss is computed when making model predictions. In fact, the observations we make predictions for often have no labels at all.
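To make the difference concrete, here is a minimal sketch in plain Python. It is not TensorFlow code: dicts of lists stand in for the tensors (or `tf.data.Dataset`) a real input function would produce, and the function names, the feature names `sat` and `tuition`, and all values are hypothetical.

```python
# Hypothetical feature columns; a real input function would return tensors
# or a tf.data.Dataset rather than plain Python lists.
FEATURES = {
    "sat": [1200.0, 1400.0, 1000.0],
    "tuition": [20000.0, 40000.0, 10000.0],
}
LABELS = [0.5, 0.2, 0.8]  # hypothetical admission rates, for training only


def train_input_fn():
    """Training/evaluation: returns a (features, labels) pair."""
    return FEATURES, LABELS


def predict_input_fn():
    """Prediction: returns the features alone; no labels are needed."""
    return FEATURES


features, labels = train_input_fn()
pred_features = predict_input_fn()
print(len(labels), sorted(pred_features))  # 3 ['sat', 'tuition']
```

The only structural difference is the return value: the training function returns a (features, labels) pair, while the prediction function returns just the features.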
The predict function returns a generator of dictionaries, where each dictionary corresponds to one predicted data observation. Each dictionary contains the output values specified in the EstimatorSpec that the model function returns in prediction mode (tf.estimator.ModeKeys.PREDICT).
preds = regressor.predict(input_fn)
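Because predict returns a generator, predictions are produced lazily, one dictionary at a time, and the generator can be consumed only once. The sketch below mimics that contract with a hypothetical `toy_predict` function (not part of TensorFlow), using an arbitrary linear formula in place of a trained model. `itertools.islice` pulls just the first few results, which is handy when you only want to inspect a sample; with a real Estimator, predict keeps yielding until the input function signals the end of its data.

```python
import itertools


def toy_predict(input_fn):
    """Hypothetical stand-in for Estimator.predict: yields one dict per row."""
    features = input_fn()
    for sat, tuition in zip(features["sat"], features["tuition"]):
        # Arbitrary linear formula used only for illustration.
        yield {"prediction": 1.5 - 0.001 * sat - 0.00001 * tuition}


def input_fn():
    # Hypothetical features for three observations; no labels.
    return {"sat": [1200.0, 1400.0, 1000.0], "tuition": [20000.0, 40000.0, 10000.0]}


preds = toy_predict(input_fn)
first_two = list(itertools.islice(preds, 2))  # take a sample of two
remaining = list(preds)  # the generator resumes where islice stopped
print(len(first_two), len(remaining))  # 2 1
print(list(preds))  # [] (the generator is now exhausted)
```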
By default, each dictionary contains all the values from the predictions argument of the EstimatorSpec. We can choose to return only specific values by setting the predict_keys keyword argument of predict. The argument takes a list of strings corresponding to the names of the values we want returned in each dictionary.
preds = regressor.predict(input_fn, predict_keys=['prediction'])
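The filtering done by predict_keys can be pictured with a similar stand-in. Here we assume, hypothetically, that the model function's predictions dictionary holds two entries: the raw `prediction` and a `clipped_prediction` limited to the range [0, 1], a natural bound for an admission rate. Again, `toy_predict` is not the TensorFlow API; it only illustrates the default-versus-filtered behavior described above.

```python
def toy_predict(input_fn, predict_keys=None):
    """Hypothetical stand-in for Estimator.predict with predict_keys support."""
    features = input_fn()
    for sat in features["sat"]:
        raw = 2.0 - 0.001 * sat  # arbitrary illustrative formula
        # Every value the (hypothetical) model function would place in the
        # predictions argument of its EstimatorSpec.
        all_values = {
            "prediction": raw,
            "clipped_prediction": min(max(raw, 0.0), 1.0),
        }
        if predict_keys is None:
            yield all_values  # default: return every value
        else:
            # Keep only the requested keys, in the order they were requested.
            yield {key: all_values[key] for key in predict_keys}


def input_fn():
    return {"sat": [900.0, 1500.0]}  # hypothetical features, no labels


full = next(toy_predict(input_fn))
only = next(toy_predict(input_fn, predict_keys=["prediction"]))
print(sorted(full))  # ['clipped_prediction', 'prediction']
print(sorted(only))  # ['prediction']
```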
Below is a plot of regression predictions made with Estimator (run the show_plot() function). The points represent the 2015 admission rates of 50 randomly chosen colleges.
The plot’s z-axis corresponds to admission rate, the y-axis corresponds to average SAT scores from 2013-2015, and the x-axis corresponds to average tuition from 2013-2016.
The regression model was trained on a variety of factors, of which SAT scores and college tuition were the most informative. Its task was to predict a school’s 2015 admission rate.
Points representing actual admission rates are marked in blue, while model predictions are marked in green.
If you hover over any of the points, it will tell you the name of the school and either the actual or predicted admission rate. You can also click the grey icon bar at the top right of the plot to adjust camera settings.
show_plot()