causal_datasets.Rmd
library(stat545lamke07)
The purpose of the stat545lamke07 package is to make it quick to test a method of interest on a toy data set. To that end, the functions starting with causal_ create small causal data sets on which a causal inference model can be tested.
In causal inference, one is usually interested in whether a treatment \(T\) has a measurable impact on the outcome \(Y\). Often, there is associated data \(X\) for the individuals. Under the Rubin Causal Model, each individual has two potential outcomes \(Y_0\) and \(Y_1\), which answer the question “What would the outcome have been if the patient had received treatment 0 (or treatment 1)?”. The Fundamental Problem of Causal Inference is that both potential outcomes can never be observed for the same individual, which makes causal effect estimation hard.
The general idea is to use the generate_XY() function to generate the basic data set \(S = (X,Y)\) and then add the treatment effects.
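In the notation above, the binary-treatment setup can be summarised in one line. The symbols \(Y_{\text{obs}}\) (the recorded outcome) and \(\tau\) (the average treatment effect) are introduced here for illustration and do not appear in the package:

\[
Y_{\text{obs}} = T\,Y_1 + (1 - T)\,Y_0, \qquad \tau = \mathbb{E}[\,Y_1 - Y_0\,].
\]

In real data only \(Y_{\text{obs}}\) is recorded, so \(\tau\) has to be estimated rather than computed directly.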
The function causal_XTY_binary() builds on the generate_XY() function by assigning treatment 1 with probability treatment_prob and treatment 0 otherwise. The treatment_effect is then added to \(Y\) when treatment 1 is selected. Note, however, that because we control the generative probability model, we can simulate both potential outcomes \(Y_0\) and \(Y_1\) exactly for every individual.
df_causal <- causal_XTY_binary(n = 40, mu = 1:4, sigma = rep(2,4),
beta_coefficients = 1:4, treatment_prob = 0.75, treatment_effect = 25)
print(head(df_causal))
#> # A tibble: 6 × 8
#> X1 X2 X3 X4 treatment Y0 Y1 Y_observed
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 -0.342 2.21 4.69 0.505 0 20.2 45.2 20.2
#> 2 0.959 1.76 1.61 3.72 1 24.2 49.2 49.2
#> 3 0.402 3.31 5.10 5.58 0 44.6 69.6 44.6
#> 4 1.48 3.20 4.93 4.66 1 41.3 66.3 66.3
#> 5 2.09 2.40 3.39 1.33 1 22.4 47.4 47.4
#> 6 5.30 2.66 -0.839 1.28 1 13.2 38.2 38.2
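Because both potential outcomes are stored, the true effect can be compared directly with what one would estimate from the observed column alone. The sketch below re-enters the six printed rows by hand so that it runs without the package; the name df_head is introduced here purely for illustration.

```r
# Six rows copied from the printed output above (values rounded as displayed).
df_head <- data.frame(
  treatment  = c(0, 1, 0, 1, 1, 1),
  Y0         = c(20.2, 24.2, 44.6, 41.3, 22.4, 13.2),
  Y1         = c(45.2, 49.2, 69.6, 66.3, 47.4, 38.2),
  Y_observed = c(20.2, 49.2, 44.6, 66.3, 47.4, 38.2)
)

# True average treatment effect: only computable because Y0 and Y1 are both simulated.
true_ate <- mean(df_head$Y1 - df_head$Y0)

# Naive estimate from the observed column: difference in group means.
naive_ate <- mean(df_head$Y_observed[df_head$treatment == 1]) -
  mean(df_head$Y_observed[df_head$treatment == 0])

print(c(true = true_ate, naive = naive_ate))
```

On these six rows the true effect is 25 for every individual, while the naive estimate is about 17.9, because the treated and untreated rows happen to have different baseline outcomes.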
If we would like to turn this data set into a purely observational one, we only need to remove the potential outcomes. This allows us to test causal inference techniques where the potential outcomes \(Y_0, Y_1\) are unknown.
df_causal %>%
dplyr::select(-c("Y0", "Y1")) %>%
head()
#> # A tibble: 6 × 6
#> X1 X2 X3 X4 treatment Y_observed
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 -0.342 2.21 4.69 0.505 0 20.2
#> 2 0.959 1.76 1.61 3.72 1 49.2
#> 3 0.402 3.31 5.10 5.58 0 44.6
#> 4 1.48 3.20 4.93 4.66 1 66.3
#> 5 2.09 2.40 3.39 1.33 1 47.4
#> 6 5.30 2.66 -0.839 1.28 1 38.2
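One way to use the observational version is to check whether a simple estimator recovers the known effect. The following is a minimal base-R sketch that does not call the package. It builds a stand-in data set under two assumptions: that the covariates are independent normals with sigma read as a standard deviation, and that the baseline outcome is a noiseless linear combination of the covariates (the printed rows above are consistent with this, but the package may differ). All object names are illustrative.

```r
set.seed(545)

# Stand-in for the simulated data (assumptions: independent normal covariates,
# sigma interpreted as a standard deviation, noiseless linear baseline outcome).
n     <- 40
mu    <- 1:4
sigma <- rep(2, 4)
beta  <- 1:4

X <- sapply(seq_along(mu), function(j) rnorm(n, mean = mu[j], sd = sigma[j]))
colnames(X) <- paste0("X", seq_along(mu))

Y0         <- as.numeric(X %*% beta)          # baseline outcome
treatment  <- rbinom(n, size = 1, prob = 0.75)
Y_observed <- Y0 + 25 * treatment             # treatment effect of 25

df_obs <- data.frame(X, treatment, Y_observed)

# Naive difference in means ignores the covariates.
naive <- mean(df_obs$Y_observed[df_obs$treatment == 1]) -
  mean(df_obs$Y_observed[df_obs$treatment == 0])

# Regression adjustment: the coefficient on treatment estimates the effect.
fit      <- lm(Y_observed ~ treatment + X1 + X2 + X3 + X4, data = df_obs)
adjusted <- unname(coef(fit)["treatment"])

print(c(naive = naive, adjusted = adjusted))
```

Under these assumptions the outcome is exactly linear in the covariates and the treatment, so the adjusted estimate equals 25 up to numerical error, whereas the naive estimate varies with the sample.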
Causal inference, however, need not be restricted to binary treatments. The function causal_XTY_multiple() extends this notion of treatment to the case where more than one treatment is available. Rather than a single treatment probability, we now require a probability vector treatment_prob giving the probability with which each treatment is chosen. The treatment effect, treatment_effect, is likewise a vector whose entries are added to the existing outcome \(Y\), which serves as the baseline outcome. Again, we include the full information in the data set.
df_causal_multiple <- causal_XTY_multiple(n = 40, mu = rep(2, 5), sigma = 1:5,
beta_coefficients = 1:5,
treatment_prob = c(0.4, 0.1, 0.1, 0.2, 0.2),
treatment_effect = 1:5)
df_causal_multiple %>%
head()
#> # A tibble: 6 × 13
#> X1 X2 X3 X4 X5 Y Y1 Y2 Y3 Y4 Y5 treatment
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
#> 1 2.45 4.03 2.89 4.07 5.56 63.2 64.2 65.2 66.2 67.2 68.2 1
#> 2 4.31 2.89 5.81 0.572 5.36 56.6 57.6 58.6 59.6 60.6 61.6 1
#> 3 1.93 1.15 3.78 1.33 1.86 30.2 31.2 32.2 33.2 34.2 35.2 5
#> 4 1.84 2.84 2.40 5.66 2.62 50.4 51.4 52.4 53.4 54.4 55.4 3
#> 5 3.79 1.53 -0.434 6.71 0.646 35.6 36.6 37.6 38.6 39.6 40.6 4
#> 6 2.15 3.37 -3.13 -2.91 -3.82 -31.3 -30.3 -29.3 -28.3 -27.3 -26.3 1
#> # … with 1 more variable: Y_observed <dbl>
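The printed rows suggest that Y_observed is the potential outcome whose index matches treatment. The sketch below illustrates that reading on the six rows shown. It is an interpretation of the output, not the package's implementation, and the object names are illustrative.

```r
# Potential outcomes Y1..Y5 for the six printed rows (values rounded as displayed).
potential <- rbind(
  c( 64.2,  65.2,  66.2,  67.2,  68.2),
  c( 57.6,  58.6,  59.6,  60.6,  61.6),
  c( 31.2,  32.2,  33.2,  34.2,  35.2),
  c( 51.4,  52.4,  53.4,  54.4,  55.4),
  c( 36.6,  37.6,  38.6,  39.6,  40.6),
  c(-30.3, -29.3, -28.3, -27.3, -26.3)
)
colnames(potential) <- paste0("Y", 1:5)

# Treatment received by each of the six rows.
treatment <- c(1L, 1L, 5L, 3L, 4L, 1L)

# Indexing a matrix with (row, column) pairs picks one potential outcome per row.
y_observed <- potential[cbind(seq_len(nrow(potential)), treatment)]
print(y_observed)
```

The result, 64.2, 57.6, 35.2, 53.4, 39.6 and -30.3, agrees with the Y_observed values printed for these rows further below.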
Removing the baseline \(Y\) and the potential outcomes \(Y_i\) turns this into an observational data set on which a method (e.g. a statistical or a machine learning model) can be tested.
df_causal_multiple %>%
dplyr::select(starts_with("X"), treatment, Y_observed)
#> # A tibble: 40 × 7
#> X1 X2 X3 X4 X5 treatment Y_observed
#> <dbl> <dbl> <dbl> <dbl> <dbl> <int> <dbl>
#> 1 2.45 4.03 2.89 4.07 5.56 1 64.2
#> 2 4.31 2.89 5.81 0.572 5.36 1 57.6
#> 3 1.93 1.15 3.78 1.33 1.86 5 35.2
#> 4 1.84 2.84 2.40 5.66 2.62 3 53.4
#> 5 3.79 1.53 -0.434 6.71 0.646 4 39.6
#> 6 2.15 3.37 -3.13 -2.91 -3.82 1 -30.3
#> 7 1.63 3.53 0.925 -0.455 -1.44 1 3.44
#> 8 1.25 2.21 6.66 9.32 -6.26 1 32.7
#> 9 3.32 3.44 5.18 -0.0732 10.1 4 80.1
#> 10 1.19 -1.83 1.24 -1.36 3.27 4 16.2
#> # … with 30 more rows
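As with the binary case, a multi-treatment data set of this shape can be used to check whether an estimator recovers the known effects. The following base-R sketch does not call the package. It assumes independent normal covariates with sigma read as a standard deviation and a noiseless linear baseline, and it uses a larger n than above so that every treatment level is practically certain to appear. All object names are illustrative.

```r
set.seed(545)

# Stand-in simulation (assumptions: independent normal covariates, sigma as a
# standard deviation, noiseless linear baseline outcome).
n                <- 400
mu               <- rep(2, 5)
sigma            <- 1:5
beta             <- 1:5
treatment_prob   <- c(0.4, 0.1, 0.1, 0.2, 0.2)
treatment_effect <- 1:5

X <- sapply(seq_along(mu), function(j) rnorm(n, mean = mu[j], sd = sigma[j]))
colnames(X) <- paste0("X", seq_along(mu))

Y          <- as.numeric(X %*% beta)   # baseline outcome
treatment  <- sample(seq_along(treatment_prob), n, replace = TRUE,
                     prob = treatment_prob)
Y_observed <- Y + treatment_effect[treatment]

df_sim <- data.frame(X, treatment = factor(treatment), Y_observed)

# With treatment as a factor, each coefficient is a contrast against treatment 1.
fit <- lm(Y_observed ~ treatment + X1 + X2 + X3 + X4 + X5, data = df_sim)
contrasts_vs_1 <- unname(coef(fit)[paste0("treatment", 2:5)])
print(round(contrasts_vs_1, 6))
```

Since the simulated effects are 1 to 5, the contrasts of treatments 2 to 5 against treatment 1 come out as 1, 2, 3 and 4 up to numerical error under these assumptions.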