caual inference using Python

Published

December 25, 2022

导入数据

import pandas as pd

df = pd.read_csv('data.csv')
df
y istreatment x1 x2 x3
0 4.636388 1 -0.355052 0.441348 0.908629
1 -1.965486 0 -0.819260 -0.712998 0.037563
2 0.581781 0 1.391339 -0.017292 -0.804188
3 -2.067287 0 -0.831021 0.497860 0.349555
4 9.546829 1 1.682321 0.608986 0.937725
... ... ... ... ... ...
4995 5.229174 1 0.177594 0.565183 0.159337
4996 2.842308 1 0.549753 -0.912549 0.046224
4997 6.659550 1 1.359027 1.181659 -1.893093
4998 1.016941 0 -2.103881 0.543803 0.962677
4999 3.323807 1 1.627608 -0.923482 0.194445

数据描述 - x1,x2,x3 协变量(控制变量) - y 因变量 - istreatment 处置变量D,标注每条数据隶属于treatment或control组。1为treatment, 0为control。

from causalinference import CausalModel

Y = df['y'].values
D = df['istreatment'].values
X = df[['x1', 'x2', 'x3']].values

#CausalModel参数依次为Y, D, X。其中Y为因变量
causal = CausalModel(Y, D, X)
causal
<causalinference.causal.CausalModel at 0x1037b90a0>
print(causal.summary_stats)

Summary Statistics

                      Controls (N_c=2509)        Treated (N_t=2491)             
       Variable         Mean         S.d.         Mean         S.d.     Raw-diff
--------------------------------------------------------------------------------
              Y       -1.012        1.742        4.978        3.068        5.989

                      Controls (N_c=2509)        Treated (N_t=2491)             
       Variable         Mean         S.d.         Mean         S.d.     Nor-diff
--------------------------------------------------------------------------------
             X0       -0.343        0.940        0.336        0.961        0.714
             X1       -0.347        0.936        0.345        0.958        0.730
             X2       -0.313        0.940        0.306        0.963        0.650
causal.summary_stats.keys()
dict_keys(['N', 'K', 'N_c', 'N_t', 'Y_c_mean', 'Y_t_mean', 'Y_c_sd', 'Y_t_sd', 'rdiff', 'X_c_mean', 'X_t_mean', 'X_c_sd', 'X_t_sd', 'ndiff'])

使用OLS估计处置效应

估计处置效应最简单的方法是使用OLS方法,

所使用的函数为.est_via_ols(),其中还有adj参数,即模型是否使用了协变量,D与X的交互效应。

Yi=α+βDi+γ(XiX¯)+δDi(XiX¯)+εi

causal.est_via_ols(adj=2)
print(causal.estimates)

Treatment Effect Estimates: OLS

                     Est.       S.e.          z      P>|z|      [95% Conf. int.]
--------------------------------------------------------------------------------
           ATE      3.017      0.034     88.740      0.000      2.950      3.083
           ATC      2.031      0.040     51.183      0.000      1.953      2.108
           ATT      4.010      0.039    103.964      0.000      3.934      4.086
/Users/a182501/opt/miniconda3/lib/python3.9/site-packages/causalinference/estimators/ols.py:21: FutureWarning: `rcond` parameter will change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.
To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.
  olscoef = np.linalg.lstsq(Z, Y)[0]
  • 其中的ATE为平均处置效应
  • ATC为控制组的平均处置效应
  • ATT为实验组的平均处置效应

PSM

causal.est_propensity_s()
print(causal.propensity)

Estimated Parameters of Propensity Score

                    Coef.       S.e.          z      P>|z|      [95% Conf. int.]
--------------------------------------------------------------------------------
     Intercept      0.005      0.035      0.145      0.885     -0.063      0.073
            X1      0.999      0.041     24.495      0.000      0.919      1.079
            X0      1.000      0.041     24.543      0.000      0.920      1.080
            X2      0.933      0.040     23.181      0.000      0.855      1.012

分层方法

causal.stratify_s()  
print(causal.strata) 

Stratification Summary

              Propensity Score         Sample Size     Ave. Propensity   Outcome
   Stratum      Min.      Max.  Controls   Treated  Controls   Treated  Raw-diff
--------------------------------------------------------------------------------
         1     0.001     0.043       153         5     0.024     0.029    -0.049
         2     0.043     0.069       148         8     0.056     0.059     0.142
         3     0.070     0.118       283        29     0.093     0.092     0.953
         4     0.119     0.178       268        45     0.147     0.147     1.154
         5     0.178     0.240       247        65     0.208     0.210     1.728
         6     0.240     0.361       451       174     0.299     0.300     2.093
         7     0.361     0.427       196       117     0.393     0.395     2.406
         8     0.427     0.499       153       159     0.465     0.464     2.868
         9     0.499     0.532        82        75     0.515     0.515     2.973
        10     0.532     0.568        65        91     0.551     0.553     3.259
        11     0.568     0.630       114       198     0.600     0.601     3.456
        12     0.630     0.758       180       445     0.693     0.696     3.918
        13     0.758     0.818        77       236     0.787     0.789     4.503
        14     0.818     0.876        57       255     0.845     0.849     4.937
        15     0.876     0.933        23       289     0.904     0.904     5.171
        16     0.933     0.998        12       300     0.957     0.963     6.822