The get_predicted() function is a robust, flexible and user-friendly alternative to base R's predict() function. Additional features and advantages include availability of uncertainty intervals (CI), bootstrapping, a more intuitive API and support for more models than base R's predict() function. However, although the interface is simplified, it is still very important to read the documentation of the arguments. This is because making "predictions" (a loose term for a variety of things) is a non-trivial process, with lots of caveats and complications. Read the Details section for more information.

get_predicted(x, ...)

# S3 method for lm
get_predicted(
  x,
  data = NULL,
  predict = "expectation",
  iterations = NULL,
  verbose = TRUE,
  ...
)

# S3 method for stanreg
get_predicted(
  x,
  data = NULL,
  predict = "expectation",
  iterations = NULL,
  include_random = TRUE,
  include_smooth = TRUE,
  verbose = TRUE,
  ...
)

Arguments

x

A statistical model (can also be a data.frame, in which case the second argument has to be a model).

...

Other arguments to be passed, for instance, to get_predicted_ci().

data

An optional data frame in which to look for variables with which to predict. If omitted, the data used to fit the model is used.

predict

Can be "link", "expectation" (default), "prediction", or "response". You can see these 4 options as lying on a gradient from "close to the model" to "close to the response data". The predict argument modulates two things: the scale of the output and the type of uncertainty interval (see the Details and Examples). More specifically, "link" returns predictions on the model's link scale (for logistic models, that means the log-odds scale) with a confidence interval (CI). "expectation" (default) also returns confidence intervals, but this time the output is on the response scale (for logistic models, that means probabilities). "prediction" also gives an output on the response scale, but this time associated with a prediction interval (PI), which is wider than a confidence interval (though it mostly makes sense for linear models). Finally, "response" only differs from the previous option for binomial models, where it additionally transforms the predictions into the original response's type (for instance, to a factor). Read more about this in the Details section below.

iterations

For Bayesian models, this corresponds to the number of posterior draws. If NULL, will return all the draws (one for each iteration of the model). For frequentist models, if not NULL, will generate bootstrapped draws, from which bootstrapped CIs will be computed. Iterations can be accessed by running as.data.frame() on the output.

verbose

Toggle warnings.

include_random

If TRUE (default), include all random effects in the prediction. If FALSE, don't take them into account. Can also be a formula to specify which random effects to condition on when predicting (passed to the re.form argument). If include_random = TRUE and data is provided, make sure to include the random effect variables in data as well.

include_smooth

For Generalized Additive Models (GAMs). If FALSE, will fix the value of the smooth to its average, so that the predictions do not depend on it.

Value

The fitted values (i.e. predictions for the response). For Bayesian or bootstrapped models (when iterations != NULL), the iterations (stored as columns, with observations as rows) can be accessed via as.data.frame().
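To make the "iterations as columns, observations as rows" layout concrete, here is a minimal base R sketch of a bootstrap for lm() predictions. It is an illustration only: the case-resampling scheme, the percentile interval, and the names n_iter, draws and boot_ci are assumptions made for this example, not a description of how get_predicted() computes its bootstrapped draws.

```r
# Hypothetical illustration in base R; not the internals of get_predicted().
set.seed(123)
data(mtcars)

n_iter <- 200  # plays the role of the `iterations` argument

# Each column holds the predictions from one refit on resampled rows,
# so rows are observations and columns are iterations.
draws <- replicate(n_iter, {
  idx <- sample(nrow(mtcars), replace = TRUE)
  refit <- lm(mpg ~ cyl + hp, data = mtcars[idx, ])
  predict(refit, newdata = mtcars)
})

# Percentile interval for each observation (an assumed choice of interval)
boot_ci <- t(apply(draws, 1, quantile, probs = c(0.025, 0.975)))
colnames(boot_ci) <- c("CI_low", "CI_high")

head(cbind(Predicted = rowMeans(draws), boot_ci))
```

Because the resampling is random, two runs without a fixed seed give slightly different intervals, which is also why repeated bootstrapped calls in the Examples section do not return identical numbers.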

Details

In insight::get_predicted(), the predict argument jointly modulates two separate concepts, the scale and the uncertainty interval.

Confidence Interval (CI) vs. Prediction Interval (PI)

  • Linear models - lm(): For linear models, prediction intervals (predict = "prediction") show the range that is likely to contain the value of a new observation, whereas confidence intervals (predict = "expectation" or predict = "link") reflect the uncertainty around the estimated parameters (and give the range of uncertainty of the regression line). In general, prediction intervals (PIs) account for both the uncertainty in the model's parameters and the random variation of the individual values. Thus, prediction intervals are always wider than confidence intervals. Moreover, prediction intervals will not necessarily become narrower as the sample size increases (as they reflect not only the quality of the fit, but also the variability within the data).

  • Generalized Linear models - glm(): For binomial models, prediction intervals are somewhat useless (for instance, for a binomial (Bernoulli) model whose dependent variable is a vector of 1s and 0s, the prediction interval is... [0, 1]).
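The difference between the two kinds of interval for a linear model can be checked with base R's predict(), whose interval argument accepts "confidence" and "prediction". This is a sketch using base R only, meant to illustrate the concept rather than the code path used by get_predicted(); the variable names are chosen for this example.

```r
data(mtcars)
model <- lm(mpg ~ cyl + hp, data = mtcars)

# Passing newdata explicitly avoids the warning predict.lm() emits when
# prediction intervals are requested for the data used to fit the model.
conf_int <- predict(model, newdata = mtcars, interval = "confidence", level = 0.95)
pred_int <- predict(model, newdata = mtcars, interval = "prediction", level = 0.95)

conf_width <- conf_int[, "upr"] - conf_int[, "lwr"]
pred_width <- pred_int[, "upr"] - pred_int[, "lwr"]

# The point predictions are identical; only the interval widths differ.
all(pred_width > conf_width)
head(cbind(fit = conf_int[, "fit"], conf_width, pred_width))
```

The prediction interval adds the residual variance to the variance of the fitted mean, so it is wider for every observation.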

Having the output on the scale of the response variable is arguably the most convenient way to understand and visualize the relationships. On the link scale, no transformation is applied and the values are on the scale of the model. For instance, for a logistic model, the response scale corresponds to the predicted probabilities, whereas the link scale gives predictions as log-odds (probabilities on the logit scale). Note that, when predict = "response", the probabilities are rounded (so that the prediction corresponds to the most likely outcome).
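The relationship between the two scales can be sketched with base R's predict() for a logistic model. The normal-approximation interval built on the link scale and then back-transformed is an assumption made for this illustration, as are the variable names; it is not presented as the exact procedure used by get_predicted().

```r
data(iris)
dat <- droplevels(iris[1:100, ])
model <- glm(Species ~ Sepal.Length, data = dat, family = "binomial")

# Link scale: log-odds, with standard errors
link <- predict(model, type = "link", se.fit = TRUE)

# Response scale: probabilities
response <- predict(model, type = "response")

# Applying the inverse logit to the link-scale values recovers the probabilities
all.equal(unname(plogis(link$fit)), unname(response))

# Assumed approach: build the interval on the link scale, then back-transform,
# which keeps the bounds inside [0, 1]
crit <- qnorm(0.975)
ci_low <- plogis(link$fit - crit * link$se.fit)
ci_high <- plogis(link$fit + crit * link$se.fit)

# Rounding the probabilities gives the most likely outcome, mapped back
# to the factor levels of the original response
predicted_class <- levels(dat$Species)[round(response) + 1]
table(predicted_class, observed = dat$Species)
```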

Examples

data(mtcars)
x <- lm(mpg ~ cyl + hp, data = mtcars)

predictions <- get_predicted(x)
predictions
#> Predicted values:
#> 
#>  [1] 21.21678 21.21678 26.07124 21.21678 15.44448 21.31239 14.10597 26.66401
#>  [9] 26.03299 20.96820 20.96820 15.34888 15.34888 15.34888 14.87083 14.67962
#> [17] 14.39279 26.58752 26.85523 26.60665 25.99475 15.92253 15.92253 14.10597
#> [25] 15.44448 26.58752 26.10948 25.68880 13.74265 19.97387 12.38501 25.76529
#> 
#> NOTE: Confidence intervals, if available, are stored as attributes and can be accessed using `as.data.frame()` on this output.

# Options and methods ---------------------
get_predicted(x, predict = "prediction")
#> Predicted values:
#> 
#>  [1] 21.21678 21.21678 26.07124 21.21678 15.44448 21.31239 14.10597 26.66401
#>  [9] 26.03299 20.96820 20.96820 15.34888 15.34888 15.34888 14.87083 14.67962
#> [17] 14.39279 26.58752 26.85523 26.60665 25.99475 15.92253 15.92253 14.10597
#> [25] 15.44448 26.58752 26.10948 25.68880 13.74265 19.97387 12.38501 25.76529
#> 
#> NOTE: Confidence intervals, if available, are stored as attributes and can be accessed using `as.data.frame()` on this output.

# Get CI
as.data.frame(predictions)
#>                     Predicted        SE    CI_low  CI_high
#> Mazda RX4            21.21678 0.7281647 19.727518 22.70605
#> Mazda RX4 Wag        21.21678 0.7281647 19.727518 22.70605
#> Datsun 710           26.07124 0.9279509 24.173366 27.96911
#> Hornet 4 Drive       21.21678 0.7281647 19.727518 22.70605
#> Hornet Sportabout    15.44448 0.9200310 13.562810 17.32616
#> Valiant              21.31239 0.7777664 19.721680 22.90310
#> Duster 360           14.10597 1.0080670 12.044237 16.16769
#> Merc 240D            26.66401 0.9225132 24.777260 28.55076
#> Merc 230             26.03299 0.9362657 24.118117 27.94787
#> Merc 280             20.96820 0.6234320 19.693139 22.24326
#> Merc 280C            20.96820 0.6234320 19.693139 22.24326
#> Merc 450SE           15.34888 0.8862558 13.536280 17.16147
#> Merc 450SL           15.34888 0.8862558 13.536280 17.16147
#> Merc 450SLC          15.34888 0.8862558 13.536280 17.16147
#> Cadillac Fleetwood   14.87083 0.8057154 13.222961 16.51871
#> Lincoln Continental  14.67962 0.8206255 13.001249 16.35798
#> Chrysler Imperial    14.39279 0.8911693 12.570146 16.21544
#> Fiat 128             26.58752 0.9099596 24.726448 28.44860
#> Honda Civic          26.85523 0.9695585 24.872258 28.83820
#> Toyota Corolla       26.60665 0.9127445 24.739874 28.47342
#> Toyota Corona        25.99475 0.9454598 24.061069 27.92843
#> Dodge Challenger     15.92253 1.1490264 13.572504 18.27255
#> AMC Javelin          15.92253 1.1490264 13.572504 18.27255
#> Camaro Z28           14.10597 1.0080670 12.044237 16.16769
#> Pontiac Firebird     15.44448 0.9200310 13.562810 17.32616
#> Fiat X1-9            26.58752 0.9099596 24.726448 28.44860
#> Porsche 914-2        26.10948 0.9205392 24.226768 27.99220
#> Lotus Europa         25.68880 1.0474287 23.546572 27.83104
#> Ford Pantera L       13.74265 1.2011595 11.286007 16.19930
#> Ferrari Dino         19.97387 0.7635547 18.412227 21.53552
#> Maserati Bora        12.38501 2.1153615  8.058613 16.71141
#> Volvo 142E           25.76529 1.0175965 23.684073 27.84651

# Bootstrapped
as.data.frame(get_predicted(x, iterations = 4))
#>    Predicted        SE    CI_low  CI_high   iter_1   iter_2    iter_3   iter_4
#> 1   21.40432 0.3480364 21.002728 21.79257 21.39969 21.82240 20.970542 21.42465
#> 2   21.40432 0.3480364 21.002728 21.79257 21.39969 21.82240 20.970542 21.42465
#> 3   26.68824 2.1071592 23.923281 28.45535 27.17221 28.54269 23.659855 27.37819
#> 4   21.40432 0.3480364 21.002728 21.79257 21.39969 21.82240 20.970542 21.42465
#> 5   15.13788 0.6993607 14.449638 16.00809 15.17414 14.88754 16.075708 14.41413
#> 6   21.50666 0.2661665 21.218778 21.82150 21.44688 21.84475 21.200284 21.53475
#> 7   13.70504 0.9691530 12.860329 14.57005 14.51349 14.57464 12.859324 12.87272
#> 8   27.32278 1.5728255 25.262793 28.63473 27.46478 28.68126 25.084254 28.06082
#> 9   26.64730 2.1423408 23.836861 28.44378 27.15333 28.53375 23.567958 27.33415
#> 10  21.13822 0.5763256 20.430601 21.72774 21.27699 21.76429 20.373213 21.13838
#> 11  21.13822 0.5763256 20.430601 21.72774 21.27699 21.76429 20.373213 21.13838
#> 12  15.03554 0.6401225 14.346119 15.79204 15.12695 14.86519 15.845966 14.30403
#> 13  15.03554 0.6401225 14.346119 15.79204 15.12695 14.86519 15.845966 14.30403
#> 14  15.03554 0.6401225 14.346119 15.79204 15.12695 14.86519 15.845966 14.30403
#> 15  14.52381 0.5199296 13.824308 14.88069 14.89101 14.75344 14.697258 13.75353
#> 16  14.31912 0.5784758 13.586160 14.79003 14.79663 14.70874 14.237774 13.53333
#> 17  14.01208 0.7481657 13.228939 14.65405 14.65506 14.64169 13.548549 13.20302
#> 18  27.24090 1.6402586 25.089953 28.61158 27.42703 28.66338 24.900460 27.97274
#> 19  27.52747 1.4071749 25.694894 28.69259 27.55916 28.72596 25.543737 28.28102
#> 20  27.26137 1.6233459 25.133163 28.61737 27.43647 28.66785 24.946409 27.99476
#> 21  26.60636 2.1775842 23.750441 28.43221 27.13446 28.52481 23.476061 27.29011
#> 22  15.64961 1.0691873 14.967235 17.08834 15.41009 14.99929 17.224417 14.96464
#> 23  15.64961 1.0691873 14.967235 17.08834 15.41009 14.99929 17.224417 14.96464
#> 24  13.70504 0.9691530 12.860329 14.57005 14.51349 14.57464 12.859324 12.87272
#> 25  15.13788 0.6993607 14.449638 16.00809 15.17414 14.88754 16.075708 14.41413
#> 26  27.24090 1.6402586 25.089953 28.61158 27.42703 28.66338 24.900460 27.97274
#> 27  26.72917 2.0720426 24.009701 28.46693 27.19108 28.55163 23.751752 27.42223
#> 28  26.27885 2.4614371 23.055655 28.34305 26.98345 28.45329 22.740888 26.93779
#> 29  13.31613 1.2812482 12.021408 14.47804 14.33417 14.48970 11.986305 12.45434
#> 30  20.07382 1.5283530 18.134607 21.47592 20.78622 21.53184 17.983899 19.99334
#> 31  11.86282 2.5413452  8.886492 14.13421 13.66407 14.17233  8.723972 10.89091
#> 32  26.36073 2.3901892 23.231920 28.36277 27.02120 28.47117 22.924681 27.02587
summary(get_predicted(x, iterations = 4)) # Same as as.data.frame(..., keep_iterations = FALSE)
#>    Predicted        SE   CI_low  CI_high
#> 1   20.79277 0.4191720 20.40466 21.32894
#> 2   20.79277 0.4191720 20.40466 21.32894
#> 3   25.80472 1.6752669 24.17555 27.89555
#> 4   20.79277 0.4191720 20.40466 21.32894
#> 5   15.02221 1.4760667 13.08730 16.25470
#> 6   20.87179 0.4824623 20.39369 21.47334
#> 7   13.91592 0.8563141 13.18083 14.97353
#> 8   26.29465 1.4076209 24.74293 27.85412
#> 9   25.77311 1.6949509 24.13895 27.89822
#> 10  20.58731 0.2716659 20.41033 20.95349
#> 11  20.58731 0.2716659 20.41033 20.95349
#> 12  14.94319 1.4180929 13.09398 16.16318
#> 13  14.94319 1.4180929 13.09398 16.16318
#> 14  14.94319 1.4180929 13.09398 16.16318
#> 15  14.54809 1.1493426 13.12739 15.70562
#> 16  14.39005 1.0556934 13.14075 15.52260
#> 17  14.15298 0.9376581 13.16079 15.24806
#> 18  26.23143 1.4375043 24.66972 27.85947
#> 19  26.45269 1.3403888 24.92595 27.84076
#> 20  26.24724 1.4298837 24.68802 27.85813
#> 21  25.74151 1.7148722 24.10235 27.90089
#> 22  15.41732 1.7795023 13.03686 16.72929
#> 23  15.41732 1.7795023 13.03686 16.72929
#> 24  13.91592 0.8563141 13.18083 14.97353
#> 25  15.02221 1.4760667 13.08730 16.25470
#> 26  26.23143 1.4375043 24.66972 27.85947
#> 27  25.83633 1.6558286 24.21216 27.89287
#> 28  25.48864 1.8818604 23.80951 27.92227
#> 29  13.61564 0.8226222 12.81194 14.62578
#> 30  19.76550 0.5808276 19.44717 20.55177
#> 31  12.49355 1.3168882 10.79937 13.50598
#> 32  25.55186 1.8389438 23.88272 27.91693

# Different prediction types ------------------------
data(iris)
data <- droplevels(iris[1:100, ])

# Fit a logistic model
x <- glm(Species ~ Sepal.Length, data = data, family = "binomial")

# Expectation (default): response scale + CI
pred <- get_predicted(x, predict = "expectation")
head(as.data.frame(pred))
#>    Predicted         SE      CI_low    CI_high
#> 1 0.16579367 0.05943589 0.078854431 0.31573138
#> 2 0.06637193 0.03625646 0.022083989 0.18286787
#> 3 0.02479825 0.01843411 0.005675609 0.10175666
#> 4 0.01498061 0.01261461 0.002839122 0.07513285
#> 5 0.10623680 0.04779474 0.042437982 0.24173444
#> 6 0.48159935 0.07901420 0.333158095 0.63336131

# Prediction: response scale + PI
pred <- get_predicted(x, predict = "prediction")
head(as.data.frame(pred))
#>    Predicted       CI_low      CI_high
#> 1 0.16579367 2.220446e-16 1.000000e+00
#> 2 0.06637193 2.220446e-16 1.000000e+00
#> 3 0.02479825 2.220446e-16 2.220446e-16
#> 4 0.01498061 2.220446e-16 2.220446e-16
#> 5 0.10623680 2.220446e-16 1.000000e+00
#> 6 0.48159935 2.220446e-16 1.000000e+00

# Link: link scale + CI
pred <- get_predicted(x, predict = "link")
head(as.data.frame(pred))
#>     Predicted        SE     CI_low    CI_high
#> 1 -1.61573668 0.4297415 -2.4580146 -0.7734588
#> 2 -2.64380391 0.5850960 -3.7905709 -1.4970369
#> 3 -3.67187114 0.7622663 -5.1658856 -2.1778567
#> 4 -4.18590475 0.8548690 -5.8614172 -2.5103923
#> 5 -2.12977030 0.5033646 -3.1163467 -1.1431939
#> 6 -0.07363584 0.3164854 -0.6939359  0.5466642

# Response: response "type" + PI
pred <- get_predicted(x, predict = "response")
head(as.data.frame(pred))
#>   Predicted CI_low    CI_high
#> 1    setosa setosa versicolor
#> 2    setosa setosa versicolor
#> 3    setosa setosa     setosa
#> 4    setosa setosa     setosa
#> 5    setosa setosa versicolor
#> 6    setosa setosa versicolor