This vignette presents how to generate predictions using the modelbased package. Warning: we will go full Bayesian. If you’re not familiar with the Bayesian framework, we recommend starting with this gentle introduction.

Prediction against original data

Generating predictions from a model is useful for a wide variety of reasons, one of them being visualisation. This can be achieved via the estimate_response() function and its visualisation spinoff, estimate_link().

Let’s start by fitting a Bayesian linear regression.

library(rstanarm)

data <- iris
model <- stan_glm(Petal.Length ~ Sepal.Length, data = data)

We might be interested in comparing the values predicted by the model to the actual “true” values. This can be done by generating predictions:

library(modelbased)

predicted <- estimate_response(model)
head(predicted)
>   Sepal.Length Predicted   SE CI_low CI_high
> 1          5.1       2.4 0.87   0.62     4.1
> 2          4.9       2.0 0.87   0.17     3.6
> 3          4.7       1.6 0.90  -0.12     3.4
> 4          4.6       1.5 0.87  -0.35     3.2
> 5          5.0       2.2 0.88   0.52     3.9
> 6          5.4       2.9 0.89   1.19     4.7

The output is a data frame containing the predicted values (the median and CI of the posterior distribution) for each observation of the original data frame (the one used for fitting the model). Hence, we can simply add the prediction column (Predicted) to the original dataset and plot the predicted values against the original ones (on top of the identity line, which represents a perfect match).

library(ggplot2)
library(dplyr)
library(see)

data$Predicted <- predicted$Predicted

data %>%
  ggplot(aes(x = Petal.Length, y = Predicted)) +
  geom_line(aes(x = Petal.Length, y = Petal.Length), linetype = "dashed") +
  geom_point() +
  ylab("Petal.Length (predicted)") +
  theme_modern()
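As a side note on the Predicted and CI columns used above: each row summarises a distribution of posterior draws. The base R sketch below illustrates that idea on simulated numbers. The draws matrix, its means and standard deviations, and the choice of a 95% equal-tailed interval are all assumptions made for illustration; this is not necessarily how modelbased computes its summaries internally.

```r
set.seed(42)

# Simulated stand-in for posterior draws (hypothetical values):
# 1000 draws (rows) for 3 observations (columns)
n_draws <- 1000
draws <- cbind(
  rnorm(n_draws, mean = 2.4, sd = 0.9),
  rnorm(n_draws, mean = 2.0, sd = 0.9),
  rnorm(n_draws, mean = 1.6, sd = 0.9)
)

# Summarise each observation (column) by the median of its draws
# and an assumed 95% equal-tailed interval
summary_table <- data.frame(
  Predicted = apply(draws, 2, median),
  CI_low = apply(draws, 2, quantile, probs = 0.025),
  CI_high = apply(draws, 2, quantile, probs = 0.975)
)

summary_table
```

Each row of summary_table plays the same role as a row of the estimate_response() output: one point estimate and one interval per observation.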

Judging from the plot, our model does not perform too badly. What if we added information about the Species to the model?

model <- stan_glm(Petal.Length ~ Sepal.Length * Species, data = data)
data$Predicted_2 <- estimate_response(model)$Predicted

We can now plot this second set of predictions, based on a more complex model, as a red overlay on the previous points:

data %>%
  ggplot() +
  geom_line(aes(x = Petal.Length, y = Petal.Length), linetype = "dashed") +
  geom_point(aes(x = Petal.Length, y = Predicted), color = "grey") +
  geom_point(aes(x = Petal.Length, y = Predicted_2), color = "red") +
  ylab("Petal.Length (predicted)") +
  theme_modern()

The new model generates much more accurate predictions (the points lie closer to the dashed identity line).
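One way to put a number on this visual impression is the root mean squared error (RMSE) between observed and predicted values. The sketch below is only an approximation of the comparison above: it uses lm() as a fast frequentist stand-in for the two stan_glm() models, and rmse() is a hypothetical helper defined here, not a modelbased function.

```r
# Frequentist stand-ins for the two Bayesian models (assumption: used only
# to keep the sketch fast and free of sampling noise)
m1 <- lm(Petal.Length ~ Sepal.Length, data = iris)
m2 <- lm(Petal.Length ~ Sepal.Length * Species, data = iris)

# Hypothetical helper: root mean squared error between observed and predicted
rmse <- function(observed, predicted) {
  sqrt(mean((observed - predicted)^2))
}

rmse_1 <- rmse(iris$Petal.Length, fitted(m1))
rmse_2 <- rmse(iris$Petal.Length, fitted(m2))

# A lower RMSE means the points sit closer to the identity line
c(simple = rmse_1, with_species = rmse_2)
```

With the Bayesian models, the same helper could be applied to data$Predicted and data$Predicted_2 from the code above.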

Different CI levels

The purpose of CI bands is to provide information about the uncertainty related to the estimation. In the Bayesian framework, the credible intervals are directly related to the shape of the posterior distribution. Thus, it can be informative to show several CI levels at once (for instance, 69%, 89% and 99%).

predicted <- estimate_link(model, ci = c(0.69, 0.89, 0.99))

iris %>%
  ggplot(aes(x = Sepal.Length)) +
  geom_point(aes(y = Petal.Length, color = Species)) +
  geom_ribbon(data = predicted, aes(ymin = CI_low_0.99, ymax = CI_high_0.99, fill = Species), alpha = 0.2) +
  geom_ribbon(data = predicted, aes(ymin = CI_low_0.89, ymax = CI_high_0.89, fill = Species), alpha = 0.3) +
  geom_ribbon(data = predicted, aes(ymin = CI_low_0.69, ymax = CI_high_0.69, fill = Species), alpha = 0.3) +
  geom_line(data = predicted, aes(y = Predicted, color = Species), size = 1) +
  theme_modern()
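To see why the ribbons nest inside each other, it can help to compute intervals at several levels from a single set of draws. The sketch below does this in base R on simulated values. The draws, their mean and standard deviation, and the use of equal-tailed (quantile-based) intervals are assumptions for illustration; the interval type used by estimate_link() may differ.

```r
set.seed(1)

# Simulated stand-in for the posterior draws of one predicted value
draws <- rnorm(4000, mean = 4.3, sd = 0.2)

ci_levels <- c(0.69, 0.89, 0.99)

# Equal-tailed interval: cut (1 - level) / 2 of the draws from each tail
intervals <- t(sapply(ci_levels, function(level) {
  tail_prob <- (1 - level) / 2
  quantile(draws, probs = c(tail_prob, 1 - tail_prob), names = FALSE)
}))
colnames(intervals) <- c("CI_low", "CI_high")
rownames(intervals) <- ci_levels

# The interval width grows with the CI level, so each band contains the previous one
widths <- intervals[, "CI_high"] - intervals[, "CI_low"]

intervals
```

Plotting the widest band first and the narrowest last, as in the ggplot code above, lets the overlapping transparency convey where most of the posterior mass lies.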

Adding individual draws

Instead of (or in addition to) representing credible intervals, the Bayesian framework also allows us to represent every individual posterior draw. In this case, the draws correspond to all the possible links estimated by the model. This gives a nice insight into the “true” underlying probabilities, in addition to summaries like the median or the CI.

# Keep only 100 draws (keeping all the draws might be slower)
predicted <- estimate_link(model, keep_iterations = TRUE, iterations = 100)

# Format draws for plotting
iterations <- bayestestR::reshape_iterations(predicted)
iterations$group <- paste0(iterations$iter_group, iterations$Species)

iris %>%
  ggplot(aes(x = Sepal.Length)) +
  geom_point(aes(y = Petal.Length, color = Species)) +
  geom_line(data = iterations, aes(y = iter_value, color = Species, group = group), alpha = 0.05) +
  geom_line(data = predicted, aes(y = Predicted, color = Species), size = 1) +
  theme_modern()
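The reshaping step above turns a table with one column per draw into a long table with one row per draw and grid point, which is the layout geom_line() needs to draw one line per draw. The base R sketch below shows that idea on a tiny hypothetical table. The wide data frame and its iter_1 to iter_3 columns are made up for illustration, and the code is not the actual implementation of reshape_iterations(); only the iter_group and iter_value column names are taken from the code above.

```r
set.seed(123)

# Hypothetical wide table: one row per grid point, one column per posterior draw
wide <- data.frame(
  Sepal.Length = c(5, 6, 7),
  iter_1 = rnorm(3),
  iter_2 = rnorm(3),
  iter_3 = rnorm(3)
)

iter_cols <- grep("^iter_", names(wide), value = TRUE)

# Stack the draw columns: one row per (grid point, draw) pair
long <- data.frame(
  Sepal.Length = rep(wide$Sepal.Length, times = length(iter_cols)),
  iter_group = rep(seq_along(iter_cols), each = nrow(wide)),
  iter_value = unlist(wide[iter_cols], use.names = FALSE)
)

long
```

In the long table, all rows sharing the same iter_group belong to one draw, which is why that column is used to build the grouping variable for the plot.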

Animated hypothetical outcome plots can also be easily created with gganimate:

library(gganimate)

p <- iris %>%
  ggplot(aes(x = Sepal.Length)) +
  geom_point(aes(y = Petal.Length, color = Species)) +
  geom_line(data = iterations, aes(y = iter_value, color = Species, group = group)) +
  theme_modern() +
  transition_states(iter_group, 0, 1) +
  shadow_mark(past = TRUE, future = TRUE, alpha = 1 / 20, color = "grey")

gganimate::animate(p)
