Thierry Wendling - 10 months ago 67
R Question

# Predicted probabilities using the bartMachine R package are failure probabilities

If I run a BART model for classification using `bartMachine`, the returned `p_hat_train` values correspond to failure probabilities rather than success probabilities, which is what the initial implementation of BART in the `BayesTree` R package returns.

Here is an example with a simulated binary response:

``````
library(bartMachine)
library(BayesTree)
library(logitnorm)

N = 1000
X <- rnorm(N, 0, 1)
p_true <- invlogit(1.5*X)
y <- rbinom(N, 1, p_true)

## bartMachine
fit <- bartMachine(data.frame(X), as.factor(y), num_burn_in = 200,
                   num_iterations_after_burn_in = 500)
p_hat <- fit$p_hat_train

## BayesTree
fit2 <- bart(X, as.factor(y), ntree = 50, ndpost = 500)
p_hat2 <- apply(pnorm(fit2$yhat.train), 2, mean)

par(mfrow = c(2,2))
plot(p_hat, p_true, main = 'p_hat_train with bartMachine')
abline(0, 1, col = 'red')
plot(1 - p_hat, p_true, main = '1 - p_hat_train with bartMachine')
abline(0, 1, col = 'red')
plot(p_hat2, p_true, main = 'pnorm(yhat.train) with BayesTree')
abline(0, 1, col = 'red')
``````

Inspecting the `iris` example from `?bartMachine` suggests that `bartMachine` estimates the probability that an observation belongs to the first level of the `y` variable, which in your example happens to be `0`. To get your desired result, specify the levels explicitly when you convert `y` to a factor so that `1` comes first:

``````
fit <- bartMachine(data.frame(X), factor(y, levels = c("1", "0")),
                   num_burn_in = 200,
                   num_iterations_after_burn_in = 500)
``````
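To see why the explicit `levels` argument matters, here is a minimal base R sketch with a small made-up response vector (the values and variable names are only for illustration). It shows that `as.factor()` sorts the levels, so `"0"` comes first, while `factor(..., levels = )` or `relevel()` puts `"1"` first:

```r
# Hypothetical toy response, only to illustrate factor level ordering
y_toy <- c(1, 0, 1, 0, 0)

# Default conversion: levels are sorted, so "0" is the first level
y_default <- as.factor(y_toy)
levels(y_default)

# Explicit levels: "1" becomes the first level
y_explicit <- factor(y_toy, levels = c("1", "0"))
levels(y_explicit)

# relevel() moves the chosen reference level to the front of an existing factor
y_releveled <- relevel(y_default, ref = "1")
levels(y_releveled)
```

Either `y_explicit` or `y_releveled` could be passed as the response; the underlying data are unchanged and only the level order differs.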

We can see what's going on when we inspect the code for `build_bart_machine`:

``````
if (class(y) == "factor" & length(y_levels) == 2) {
    y_remaining = ifelse(y == y_levels[1], 1, 0)
    pred_type = "classification"
}
``````
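The effect of that `ifelse` line can be reproduced outside the package. The sketch below is not the package's code; it applies the same expression to a small made-up vector (names ending in `_toy`, `_default` and `_flipped` are hypothetical) to show how the level order decides which class is coded as 1:

```r
# Hypothetical toy response, only to illustrate the recoding
y_toy <- c(1, 0, 1, 0, 0)

# Default level order ("0", "1"): observations with y == 0 are coded as 1
y_default <- as.factor(y_toy)
y_remaining_default <- ifelse(y_default == levels(y_default)[1], 1, 0)

# Reversed level order ("1", "0"): observations with y == 1 are coded as 1
y_flipped <- factor(y_toy, levels = c("1", "0"))
y_remaining_flipped <- ifelse(y_flipped == levels(y_flipped)[1], 1, 0)

cbind(y = y_toy, y_remaining_default, y_remaining_flipped)
```

With the default order the recoded response is the complement of `y`, which is consistent with the predicted probabilities being those of the `0` class.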

The output from `bartMachine` (using your original specification) confirms this: `y_remaining` is 1 exactly where `y` is 0.

``````
head(cbind(fit$model_matrix_training_data, y))
#             X y_remaining y
# 1 -0.85093975           0 1
# 2  0.20955263           1 0
# 3  0.66489564           0 1
# 4 -0.09574123           1 0
# 5 -1.22480134           1 0
# 6 -0.36176273           1 0
``````