統計的因果探索#

因果グラフ(DAG)を用いて、selection on observableの仮定を置いて因果推論する場合、欠落変数バイアスが起きないように適切なDAGを設定する必要がある。

通常、DAGはドメイン知識に基づいて作られるが、データにもとづいてDAGを推定しようとするのが因果探索(causal discovery)の分野。

LiNGAM#

LiNGAM(Linear Non-Gaussian Acyclic Model) は因果探索の代表的な手法で、

  • 構造方程式モデル(線形モデル)で

  • 誤差項が非ガウスに従う

  • 非巡回(Acyclic)グラフを推定する

といった前提をおく手法(Shimizu et al., 2006

モデルの例#

モデルは構造方程式で記述される。まずOutcomeもTreatmentもCovariatesも全部\(x_i\)で表現する。

outcomeを\(x_3\)とし、\(x_1, x_2\)をTreatmentとする。

\[\begin{split} \begin{aligned} & x_1= e_1 \\ & x_2=b_{21} x_1 + e_2 \\ & x_3=b_{31} x_1 + b_{32} x_2+ e_3 \end{aligned} \end{split}\]

この構造方程式モデルを行列表記にすると

\[\begin{split} \left(\begin{array}{l} x_1 \\ x_2 \\ x_3 \end{array}\right)=\left(\begin{array}{ccc} 0 & 0 & 0 \\ b_{21} & 0 & 0 \\ b_{31} & b_{32} & 0 \end{array}\right)\left(\begin{array}{l} x_1 \\ x_2 \\ x_3 \end{array}\right)+\left(\begin{array}{l} e_1 \\ e_2 \\ e_3 \end{array}\right) \end{split}\]

となる。係数行列を\(B\)とおき、それ以外もベクトルにすると

\[ \boldsymbol{x} = B \boldsymbol{x} + \boldsymbol{e} \]

となる。

※非巡回の制約があるため、因果の流れの通りに(最上流を\(x_1\)、最下流を\(x_3\)に)変数を並べると係数行列\(B\)が下三角行列になる

係数行列の推定#

モデルの変形#

\[ \boldsymbol{x} = B \boldsymbol{x} + \boldsymbol{e} \]

を変形すると

\[\begin{split} \begin{align} \boldsymbol{x} &= B \boldsymbol{x} + \boldsymbol{e}\\ \boldsymbol{x} (I - B) &= \boldsymbol{e}\\ \boldsymbol{x} &= (I - B)^{-1} \boldsymbol{e}\\ \end{align} \end{split}\]

(ここで\(I\)は単位行列)

\(A := (I - B)^{-1}\)とすると

\[ \boldsymbol{x} = A \boldsymbol{e} \]

この\(A\)を求めれば良い。

\(A\)の推定では独立成分分析を使用する。

独立成分分析#

主成分分析はもとのデータを各変数の相関が0になるような新しい変数に変換する手法。

ガウス分布に従うデータなら、主成分分析で変数間の関係が独立になる。

非ガウス分布に従うデータだと相関は0になるが独立にはならない

独立成分分析は主成分分析の結果\(x_{pca}\)に対して線形変換を施して新たな変数\(x_{ica}\)を作成する。

\[ x = A_{ica} x_{ica} \]

という分解ができるため、LiNGAMではこれを使う。

ただし、\(B\)が下三角行列なので、\(A^{-1}\)は対角成分が1で、対角成分より上側のすべての要素が0である必要があり、独立成分分析の後にそうした後処理が必要になる

LiNGAMの計算手順#

※ 小川(2020)と同じように実装したはずだが、一部符号が反転してて、あまりうまくいかなかった

適当にデータを生成する

\[\begin{split} \begin{aligned} & x_1= e_1 \\ & x_2= 3 x_1 + e_2 \\ & x_3= 4 x_1 + 5 x_2+ e_3 \end{aligned} \end{split}\]
# 適当にデータを生成する
import numpy as np
import pandas as pd

n = 1000
np.random.seed(0)

# 非ガウスの誤差
e1 = np.random.uniform(size=n)
e2 = np.random.uniform(size=n)
e3 = np.random.uniform(size=n)

# 各変数の生成
x1 = e1
x2 = 3*x1 + e2
x3 = 4*x1 + 5*x2 + e3

# DFにする
df = pd.DataFrame({"x1": x1, "x2": x2, "x3": x3})

続いて、独立成分分析を行う

from sklearn.decomposition import FastICA
ica = FastICA(random_state=0, max_iter=10000).fit(df)

# ICAで求めた行列A
A_ica = ica.mixing_
A_ica_inv = np.linalg.inv(A_ica)

A_ica_inv.round(1)
array([[ 13.8,  17.5,  -3.5],
       [  2.9,  -0.7,   0.1],
       [-10.1,   3.3,   0. ]])

続いて、\(A_{ica}^{-1}\)に対して

  1. 行の順番を変換

  2. 行の大きさを調整

して対角成分が1で対角成分より上側の要素が全部0な行列になるようにする

munkresパッケージのハンガリアンアルゴリズムという対角成分の和を最小にする問題を解く

!pip install munkres
Collecting munkres
  Downloading munkres-1.1.4-py2.py3-none-any.whl (7.0 kB)
Installing collected packages: munkres
Successfully installed munkres-1.1.4
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.0.1 -> 24.3.1
[notice] To update, run: pip install --upgrade pip
# 1. 行の順番を変換
# 絶対値の逆数にして、対角成分の和を最小化する問題に置き換える
A_ica_inv_small = 1 / np.abs(A_ica_inv)

# 対角成分の和を最小にする行の入れ替えを行う
from munkres import Munkres
m = Munkres()
ixs = np.vstack(m.compute(A_ica_inv_small))

# 順番の入れ替え
ixs = ixs[np.argsort(ixs[:, 0]), :]
ixs_perm = ixs[:, 1]
A_ica_inv_perm = np.zeros_like(A_ica_inv)
A_ica_inv_perm[ixs_perm] = A_ica_inv

# 2. 行の大きさを調整
# 対角成分が1になるよう調整
A_ica_inv_perm_adjusted = A_ica_inv_perm / np.diag(A_ica_inv_perm)

A_ica_inv_perm_adjusted.round(1)
array([[ 1. , -0.2, -0. ],
       [-3.5,  1. , -0. ],
       [ 4.8,  5.3,  1. ]])

\(A^{-1} = I - B\)なので\(B= I - A^{-1}\)として\(B\)を求める

I = np.eye(3)
B = I - A_ica_inv_perm_adjusted

B.round(1)
array([[ 0. ,  0.2,  0. ],
       [ 3.5,  0. ,  0. ],
       [-4.8, -5.3,  0. ]])

lingamパッケージで実践#

Pythonだとlingamパッケージがある

!pip install lingam
Hide code cell output
Collecting lingam
  Downloading lingam-1.9.1-py3-none-any.whl (103 kB)
?25l     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/103.1 kB ? eta -:--:--
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 102.4/103.1 kB 2.9 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 103.1/103.1 kB 2.5 MB/s eta 0:00:00
?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/site-packages (from lingam) (1.26.4)
Collecting pygam
  Downloading pygam-0.9.1-py3-none-any.whl (522 kB)
?25l     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/522.0 kB ? eta -:--:--
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━ 430.1/522.0 kB 16.4 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 522.0/522.0 kB 12.9 MB/s eta 0:00:00
?25hRequirement already satisfied: scipy in /usr/local/lib/python3.10/site-packages (from lingam) (1.13.1)
Requirement already satisfied: pandas in /usr/local/lib/python3.10/site-packages (from lingam) (2.2.3)
Requirement already satisfied: statsmodels in /usr/local/lib/python3.10/site-packages (from lingam) (0.14.4)
Requirement already satisfied: semopy in /usr/local/lib/python3.10/site-packages (from lingam) (2.3.11)
Collecting psy
  Downloading psy-0.0.1-py2.py3-none-any.whl (38 kB)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/site-packages (from lingam) (1.5.2)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/site-packages (from lingam) (3.4.2)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/site-packages (from lingam) (3.10.0)
Requirement already satisfied: graphviz in /usr/local/lib/python3.10/site-packages (from lingam) (0.20.3)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (4.55.3)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (2.9.0.post0)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (24.2)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (0.12.1)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (3.2.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (1.4.8)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (1.3.1)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.10/site-packages (from matplotlib->lingam) (11.1.0)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/site-packages (from pandas->lingam) (2024.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/site-packages (from pandas->lingam) (2024.2)
Collecting progressbar2
  Downloading progressbar2-4.5.0-py3-none-any.whl (57 kB)
?25l     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/57.1 kB ? eta -:--:--
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.1/57.1 kB 22.5 MB/s eta 0:00:00
?25h
Collecting scipy
  Downloading scipy-1.11.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.4 MB)
?25l     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/36.4 MB ? eta -:--:--
     ━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/36.4 MB 60.7 MB/s eta 0:00:01
     ━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.0/36.4 MB 115.7 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━ 15.1/36.4 MB 203.9 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━ 21.4/36.4 MB 190.8 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━ 28.1/36.4 MB 185.5 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━ 34.9/36.4 MB 197.0 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 36.4/36.4 MB 182.7 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 36.4/36.4 MB 182.7 MB/s eta 0:00:01
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 36.4/36.4 MB 70.1 MB/s eta 0:00:00
?25h
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/site-packages (from scikit-learn->lingam) (3.5.0)
Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/site-packages (from scikit-learn->lingam) (1.4.2)
Requirement already satisfied: numdifftools in /usr/local/lib/python3.10/site-packages (from semopy->lingam) (0.9.41)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/site-packages (from semopy->lingam) (1.13.1)
Requirement already satisfied: patsy>=0.5.6 in /usr/local/lib/python3.10/site-packages (from statsmodels->lingam) (1.0.1)
Collecting python-utils>=3.8.1
  Downloading python_utils-3.9.1-py2.py3-none-any.whl (32 kB)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib->lingam) (1.17.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/site-packages (from sympy->semopy->lingam) (1.3.0)
Requirement already satisfied: typing_extensions>3.10.0.2 in /usr/local/lib/python3.10/site-packages (from python-utils>=3.8.1->progressbar2->psy->lingam) (4.12.2)
Installing collected packages: scipy, python-utils, progressbar2, pygam, psy, lingam
  Attempting uninstall: scipy
    Found existing installation: scipy 1.13.1
    Uninstalling scipy-1.13.1:
      Successfully uninstalled scipy-1.13.1
Successfully installed lingam-1.9.1 progressbar2-4.5.0 psy-0.0.1 pygam-0.9.1 python-utils-3.9.1 scipy-1.11.4
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.0.1 -> 24.3.1
[notice] To update, run: pip install --upgrade pip

Tutorial: DirectLiNGAM — LiNGAM 1.8.2 documentation

import lingam

model = lingam.DirectLiNGAM()
model.fit(df)

# adjacency_matrix_ で推定した係数行列Bを見ることができる
print(model.adjacency_matrix_.round(1))
[[0.  0.  0. ]
 [3.  0.  0. ]
 [3.9 5.  0. ]]

参考文献#