point-biserial / biserial

Point-biserial correlation

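The point-biserial correlation coefficient measures the correlation between a continuous variable and a natural binary variable (a binary variable that was not obtained by discretizing a continuous variable).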

Point-biserial correlation

For a continuous variable \(X\) and a binary variable \(Y \in\{0,1\}\):

\[ r_{\text{pbi}}=\frac{\bar{X}_1-\bar{X}_0}{s_X} \cdot \sqrt{p q} \]

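Notation:

  • \(\bar{X}_1, \bar{X}_0\): the mean of \(X\) when \(Y=1\) and when \(Y=0\), respectively

  • \(s_X\): the standard deviation of \(X\), computed with denominator \(n\) (the population standard deviation)

  • \(p\): the proportion of observations with \(Y=1\)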

  • \(q=1-p\)

import numpy as np

def point_biserial_correlation(x: np.ndarray, y: np.ndarray) -> float:
    """
    Compute the point-biserial correlation between a continuous variable x
    and a binary variable y (0 or 1), assuming y is a true categorical variable.

    Parameters
    ----------
    x : np.ndarray
        Continuous variable.
    y : np.ndarray
        Binary variable (0 and 1), true categories.

    Returns
    -------
    float
        Point-biserial correlation coefficient.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    
    assert set(np.unique(y)).issubset({0, 1}), "y must be binary (0/1)"
    
    x1 = x[y == 1]
    x0 = x[y == 0]
    
    M1 = np.mean(x1)
    M0 = np.mean(x0)
    s = np.std(x, ddof=0)
    
    p = np.mean(y)
    q = 1 - p

    return (M1 - M0) / s * np.sqrt(p * q)
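
# Simulate x and a binary y whose probability of being 1 increases with x
np.random.seed(0)
x = np.random.normal(size=100)
x_ = x + abs(x.min())               # shift so that the minimum becomes 0
p = (x_ - x_.min()) / x_.max()      # rescale to [0, 1] for use as probabilities
y = np.random.binomial(n=1, p=p, size=100)

point_biserial_correlation(x, y)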
0.42540375845000344

scipy.stats also provides an implementation: scipy.stats.pointbiserialr.
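As a quick cross-check, the sketch below evaluates the formula directly and compares it with scipy.stats.pointbiserialr on simulated data (the data-generating process is an arbitrary assumption for illustration). It also shows that if the sample standard deviation (ddof=1) is used for \(s_X\), then \(\sqrt{pq}\) has to be replaced by \(\sqrt{n_1 n_0 / (n(n-1))}\) to obtain the same value.

```python
import numpy as np
from scipy.stats import pointbiserialr

# Assumed toy data: natural 0/1 groups, group 1 shifted upward
rng = np.random.default_rng(0)
n = 200
y = rng.integers(0, 2, size=n)
x = rng.normal(size=n) + 0.8 * y

x1, x0 = x[y == 1], x[y == 0]
n1, n0 = len(x1), len(x0)
p = n1 / n
q = 1 - p

# Formula with the population SD (ddof=0): uses sqrt(p*q)
r_pop = (x1.mean() - x0.mean()) / x.std(ddof=0) * np.sqrt(p * q)

# Same quantity written with the sample SD (ddof=1)
r_samp = (x1.mean() - x0.mean()) / x.std(ddof=1) * np.sqrt(n1 * n0 / (n * (n - 1)))

# SciPy returns the correlation and a two-sided p-value
r_scipy, pvalue = pointbiserialr(x, y)

print(r_pop, r_samp, r_scipy, pvalue)
```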

Point-biserial equals Pearson's product-moment correlation

When \(Y\in\{0,1\}\), Pearson's product-moment correlation coefficient equals the point-biserial correlation.

Proof of Point-Biserial Correlation being a special case of Pearson Correlation - Cross Validated

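Because \(Y\) is binary, the least-squares regression line of \(X\) on \(Y\) is the line through the two points \((0, M_0)\) and \((1, M_1)\), where \(M_0\) is the mean of \(X\) among the points with \(Y=0\) and \(M_1\) is the mean of \(X\) among the points with \(Y=1\).

The slope of this regression line is \(\beta = (M_1 - M_0) / (1 - 0) = M_1 - M_0\).

Pearson's correlation coefficient is defined as

\[ r=\frac{\operatorname{Cov}(X, Y)}{s_X s_Y} \]

and, from the definition of the regression coefficient,

\[ \beta=\frac{\operatorname{Cov}(X, Y)}{\operatorname{Var}(Y)}=\frac{s_X}{s_Y} r \implies r = \beta \cdot \frac{s_Y}{s_X} \]

For a 0/1 variable, \(\operatorname{Var}(Y) = p(1-p)\) (with denominator \(n\)), so \(s_Y = \sqrt{p(1-p)}\) and therefore

\[ r = (M_1 - M_0) \cdot \frac{s_Y}{s_X} =\frac{M_1 - M_0}{s_X} \cdot \sqrt{p(1-p)} \]

which is exactly the point-biserial correlation.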

from scipy.stats import pointbiserialr, pearsonr
pointbiserialr(x, y) == pearsonr(x, y)
True
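The derivation can also be checked numerically. This is a minimal sketch on simulated data (the effect size 1.5 and the sample size are arbitrary assumptions): it fits the least-squares line of \(X\) on \(Y\) with np.polyfit and confirms that the slope is \(M_1 - M_0\), the intercept is \(M_0\), \(s_Y = \sqrt{p(1-p)}\), and \(\beta \cdot s_Y / s_X\) matches Pearson's r from np.corrcoef.

```python
import numpy as np

# Assumed toy data: binary y, x shifted by 1.5 in group 1
rng = np.random.default_rng(1)
n = 500
y = rng.integers(0, 2, size=n).astype(float)
x = 1.5 * y + rng.normal(size=n)

# OLS regression of X on Y; polyfit returns [slope, intercept] for deg=1
beta, intercept = np.polyfit(y, x, deg=1)

M1, M0 = x[y == 1].mean(), x[y == 0].mean()
p = y.mean()
s_x = x.std(ddof=0)
s_y = y.std(ddof=0)

# The fitted line passes through (0, M0) and (1, M1)
print(beta, M1 - M0)
print(intercept, M0)

# r = beta * s_Y / s_X, with s_Y = sqrt(p(1-p)) for a 0/1 variable
r_from_slope = beta * s_y / s_x
r_pearson = np.corrcoef(x, y)[0, 1]
print(r_from_slope, r_pearson)
```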

Biserial correlation

The biserial correlation coefficient measures the correlation between a continuous variable and an artificially dichotomized variable (a continuous variable split at a threshold).

Biserial correlation

For a continuous variable \(X\) and a dichotomized variable \(Y \in\{0,1\}\):

\[ r_{\text{bi}}=\frac{\bar{X}_1-\bar{X}_0}{s_X} \cdot \frac{p q}{\phi(z)} \]

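Notation:

  • \(\bar{X}_1, \bar{X}_0\): the means of the continuous variable \(X\) in the groups \(Y=1\) and \(Y=0\)

  • \(s_X\): the standard deviation of \(X\) over all observations

  • \(p, q\): the proportions of \(Y=1\) and \(Y=0\) \((p+q=1)\)

  • \(z\): the z-value of the latent threshold on the standard normal distribution corresponding to \(Y=1\) (the point whose cumulative probability is \(p\)); since \(\phi\) is symmetric, using \(q\) instead gives the same \(\phi(z)\)

  • \(\phi(z)\): the probability density function (PDF) of the standard normal distribution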

\[ \phi(z)=\frac{1}{\sqrt{2 \pi}} e^{-z^2 / 2} \]
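A small sketch of how \(z\) and \(\phi(z)\) can be obtained with scipy.stats.norm, assuming a hypothetical proportion \(p = 0.3\). It checks the explicit PDF formula above against norm.pdf and shows that using \(q\) in place of \(p\) yields the same \(\phi(z)\).

```python
import numpy as np
from scipy.stats import norm

p = 0.3          # hypothetical proportion of Y=1
q = 1 - p

z = norm.ppf(p)                 # threshold whose cumulative probability is p
phi_z = norm.pdf(z)             # standard normal density at the threshold
phi_manual = np.exp(-z**2 / 2) / np.sqrt(2 * np.pi)   # explicit PDF formula

# phi is symmetric, so using q instead of p gives the same density value
phi_from_q = norm.pdf(norm.ppf(q))

print(z, phi_z, phi_manual, phi_from_q)
```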

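Assumptions:

  • \(Y\) must be assumed to be not a natural (binary) category but a continuous variable artificially cut at a threshold

  • The continuous variable \(X\) should preferably be close to normally distributed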

import numpy as np
from scipy.stats import norm

def biserial_correlation(x: np.ndarray, y: np.ndarray) -> float:
    """
    Compute the biserial correlation coefficient between a continuous variable x
    and a dichotomized variable y (0 or 1), assuming y was split from a latent normal variable.

    Parameters
    ----------
    x : np.ndarray
        Continuous variable.
    y : np.ndarray
        Dichotomous variable (0 and 1), assumed to be derived from a latent normal variable.

    Returns
    -------
    float
        Biserial correlation coefficient.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    assert set(np.unique(y)).issubset({0, 1}), "y must be binary (0/1)"
    
    x1 = x[y == 1]
    x0 = x[y == 0]
    M1 = np.mean(x1)
    M0 = np.mean(x0)
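    # Note: ddof=1 (sample SD). The identity r_bi = r_pbi * sqrt(pq) / phi(z)
    # holds exactly with ddof=0; ddof=1 scales the result by sqrt((n-1)/n).
    s = np.std(x, ddof=1)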

    p = np.mean(y)
    q = 1 - p
    z = norm.ppf(p)
    phi = norm.pdf(z)

    return (M1 - M0) / s * (p * q) / phi
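
# Latent bivariate normal with correlation rho; the second coordinate is
# dichotomized at 0.5 to create an artificial binary variable y
from scipy.stats import multivariate_normal

rho = 0.05
cov = np.array([[1, rho], [rho, 1]])
X = multivariate_normal.rvs(cov=cov, size=100, random_state=0)
x = X[:, 0]
y = 1 * (X[:, 1] > 0.5)
biserial_correlation(x, y)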
0.06162599756903869

Background

The biserial correlation was proposed by Pearson (1909), and its asymptotic variance was derived by Soper (1914).

Intuition: how biserial and point-biserial differ

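  • The biserial correlation targets an artificial binary variable, so it deals with the distribution of a single continuous variable \(y_{\text{latent}}\) cut at some threshold.

  • The point-biserial correlation targets a natural binary variable, so each of the two classes \(\{0, 1\}\) has its own distribution, and these distributions may overlap.

    • For example, a person with high latent ability \(y_{\text{latent}}\) is likely to answer correctly (\(y=1\)), but not with 100% probability; they may still answer incorrectly by chance.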


import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
np.random.seed(42)
n = 500
# Latent continuous variable
x = np.random.normal(loc=0, scale=1, size=n)
threshold = 0.3
# Biserial: y is x cut at a fixed threshold (artificial dichotomization)
y_biserial = (x > threshold).astype(int)
# Point-biserial: y depends on x plus noise, so the two groups' x values overlap
y_point = (x + np.random.normal(0, 0.5, n) > 0).astype(int)

fig, axs = plt.subplots(1, 2, figsize=(8, 3), sharey=True)

axs[0].hist(x[y_biserial == 0], bins=20, alpha=0.6, label="y=0")
axs[0].hist(x[y_biserial == 1], bins=20, alpha=0.6, label="y=1")
axs[0].axvline(threshold, color='k', linestyle='--', label='Threshold')
axs[0].set_title("Biserial: Artificial Dichotomization")
# \mathrm instead of \text: the matplotlib mathtext parser used here rejected \text
axs[0].set_xlabel(r"$y_{\mathrm{latent}}$")
axs[0].legend()

axs[1].hist(x[y_point == 0], bins=20, alpha=0.6, label="y=0")
axs[1].hist(x[y_point == 1], bins=20, alpha=0.6, label="y=1")
axs[1].set_title("Point-Biserial: Natural Categories")
axs[1].set_xlabel(r"$y_{\mathrm{latent}}$")
axs[1].legend()

plt.suptitle("Comparison of Biserial vs Point-Biserial Correlation")
plt.tight_layout()
plt.show()

Hide code cell source

# Create scatter plots with color-coded categories for better intuition
fig, axs = plt.subplots(1, 2, figsize=(8, 3), sharey=True)

# Biserial: artificial thresholding
axs[0].scatter(x, y_biserial + np.random.normal(0, 0.02, size=n), alpha=0.4, c=y_biserial, cmap='coolwarm')
axs[0].axvline(threshold, color='k', linestyle='--', label='Threshold')
axs[0].set_title("Biserial: Artificial Dichotomy")
axs[0].set_xlabel("X (continuous)")
axs[0].set_ylabel("Y (binary)")
axs[0].legend()

# Point-biserial: natural binary
axs[1].scatter(x, y_point + np.random.normal(0, 0.02, size=n), alpha=0.4, c=y_point, cmap='coolwarm')
axs[1].set_title("Point-Biserial: Natural Categories")
axs[1].set_xlabel("X (continuous)")

plt.suptitle("Scatter Plot: Biserial vs Point-Biserial Correlation")
plt.tight_layout()
plt.show()
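(Figure: "Scatter Plot: Biserial vs Point-Biserial Correlation". Left panel: "Biserial: Artificial Dichotomy"; right panel: "Point-Biserial: Natural Categories". x-axis: X (continuous), y-axis: Y (binary).)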

Biserial is point-biserial multiplied by a correction factor

Start from the point-biserial correlation (which equals Pearson's r):

\[ r_{\text{pbi}}=\frac{\bar{X}_1-\bar{X}_0}{s_X} \cdot \sqrt{p q} \]

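Multiplying it by the correction factor \(\frac{\sqrt{pq}}{\phi(z)}\) gives the biserial correlation. Here \(\phi\) is the standard normal density, and \(z\) is the cut point on the latent standard normal scale that splits off the proportion \(p\). Because \(\phi\) is symmetric, it does not matter which tail \(p\) refers to.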

\[ r_{\text{bi}} = r_{\text{pbi}} \cdot \frac{\sqrt{pq}}{\phi(z)} = \frac{\bar{X}_1-\bar{X}_0}{s_X} \cdot \frac{p q}{\phi(z)} \]
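The formula above can be sketched in Python as follows. This is a minimal illustration under stated assumptions and is separate from the `biserial_correlation` function called later in this note.

- The helper name `biserial_from_pbi` is hypothetical.
- `y` is coded 0/1 and was produced by thresholding a latent standard normal variable.
- The cut point is taken as \(z = \Phi^{-1}(q)\), so that the upper-tail area is \(p\).

```python
import numpy as np
from scipy.stats import norm


def biserial_from_pbi(x, y):
    """Hypothetical helper: r_bi = r_pbi * sqrt(pq) / phi(z).

    Assumes y is 0/1 and came from thresholding a latent standard normal
    variable, so that the proportion p = P(Y = 1) determines the cut point z.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y)
    p = y.mean()            # proportion of Y = 1
    q = 1 - p
    s_x = x.std(ddof=0)     # population SD, as in the point-biserial formula
    r_pbi = (x[y == 1].mean() - x[y == 0].mean()) / s_x * np.sqrt(p * q)
    z = norm.ppf(q)         # cut point whose upper-tail area is p
    return r_pbi * np.sqrt(p * q) / norm.pdf(z)


# Usage: latent bivariate normal with rho = 0.6; Y is the second variable thresholded at 0.3
rng = np.random.default_rng(0)
rho_demo = 0.6
latent = rng.multivariate_normal([0, 0], [[1, rho_demo], [rho_demo, 1]], size=200_000)
x_demo = latent[:, 0]
y_demo = (latent[:, 1] > 0.3).astype(int)
print(biserial_from_pbi(x_demo, y_demo))  # should be close to rho_demo
```

With a large sample, the corrected value should land near the latent \(\rho\). The uncorrected point-biserial value would be noticeably smaller.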

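Peters & Van Voorhis (1940) considered random variables \(X\) and \(Y\) that follow a bivariate normal distribution with correlation \(\rho\). They reported that if \(X_d\) is \(X\) dichotomized at its mean or median, the correlation between \(X_d\) and \(Y\) is \(0.798\rho\). In other words, dichotomization underestimates the true correlation by a factor of about 0.8.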

According to Cohen (1983), the attenuation factor due to dichotomization reported by Peters and Van Voorhis (1940) generalizes as follows:

\[ e = \frac{\phi(z)}{\sqrt{p(1 - p)}} \]
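  • \(\phi(z)\): the standard normal density function, evaluated at the cut point \(z\)

  • \(p\): the proportion of the dichotomized variable falling in one category (e.g., at or above the cut point)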
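As a check, take the Peters & Van Voorhis setting: a normal variable split at its mean or median, so \(p = 1/2\) and \(z = 0\). Plugging this into the formula recovers their factor:

\[ e\left(\tfrac{1}{2}\right) = \frac{\phi(0)}{\sqrt{\tfrac{1}{2}\cdot\tfrac{1}{2}}} = \frac{1/\sqrt{2\pi}}{1/2} = \sqrt{\frac{2}{\pi}} \approx 0.798 \]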

The biserial correlation multiplies the point-biserial correlation by the reciprocal of this attenuation factor \(e\).

import numpy as np
from scipy.stats import pearsonr

# Data from a bivariate standard normal distribution
rho = 0.75
size = 10000
np.random.seed(0)
data = np.random.multivariate_normal(mean=[0, 0], cov=[[1, rho], [rho, 1]], size=size)
x, y = data[:, 0], data[:, 1]

# Dichotomize x at its mean
xd = 1 * (x >= np.mean(x))
rho_d = pearsonr(xd, y)[0]
print(f"true ρ={rho:.2f}, Pearson after dichotomization={rho_d:.3f}, ratio={rho_d / rho:.3f}")


from scipy.stats import norm
p = xd.mean()
mean_z = 0  # assuming a standard normal split at its mean, the cut point is z = 0
e = norm.pdf(mean_z) / np.sqrt(p * (1 - p))
print(f"{e=:.3f}, corrected (biserial)={rho_d / e:.3f}")
true ρ=0.75, Pearson after dichotomization=0.601, ratio=0.801
e=0.798, corrected (biserial)=0.753
from scipy.stats import multivariate_normal

rho = 0.5
cov = np.array([[1, rho], [rho, 1]])
X = multivariate_normal.rvs(cov=cov, size=100, random_state=0)
x = X[:, 0]
y = 1 * (X[:, 1] > 0.5)
print(f"biserial: {biserial_correlation(x, y):.5f}")
print(f"point_biserial: {point_biserial_correlation(x, y):.5f}")
biserial: 0.42521
point_biserial: 0.32762
from ordinalcorr import polyserial
polyserial(x, y)
np.float64(0.4539866448381744)

How large is the difference between biserial and point-biserial?

MacCallum, R. C., Zhang, S., Preacher, K. J., & Rucker, D. D. (2002). On the practice of dichotomization of quantitative variables. Psychological Methods, 7(1), 19.

What kind of function is the correction factor?

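The attenuation of the correlation is a function of \(p = \operatorname{E}[\mathbb{1}(Y \geq \tau)]\), the proportion of values at or above the split point \(\tau\).

(This is because the position \(z\) of the threshold on the standard normal scale is itself determined by \(p\).)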
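Below is a small sketch of the mapping from \(p\) to \(z\). It assumes \(p\) is the proportion at or above the cut point, in which case \(z = \Phi^{-1}(1-p)\), computed here with `scipy.stats.norm.isf(p)`. The plotting code below uses `norm.ppf(p)`, i.e. \(\Phi^{-1}(p) = -z\), instead. Because \(\phi\) is symmetric, both conventions give the same \(\phi(z)\) and therefore the same \(e(p)\).

```python
import numpy as np
from scipy.stats import norm

p_grid = np.array([0.5, 0.3, 0.1, 0.01])

z_upper = norm.isf(p_grid)   # cut point with P(Z >= z) = p, i.e. Phi^{-1}(1 - p)
z_lower = norm.ppf(p_grid)   # Phi^{-1}(p), the mirror-image cut point (= -z_upper)

# phi is symmetric, so either convention gives the same attenuation factor
e_grid = norm.pdf(z_upper) / np.sqrt(p_grid * (1 - p_grid))
print(np.round(z_upper, 3))
print(np.round(e_grid, 3))
```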

\[ e(p) = \frac{\phi(z)}{\sqrt{p(1 - p)}} \]

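\(e(p)\) becomes extremely small when \(p\) is close to 0 or 1.

The biserial correlation uses \(1/e(p)\) as its correction factor, so in those cases the correction factor becomes extremely large.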
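The following rough asymptotic argument shows why \(e(p) \to 0\) at the extremes. It is a sketch based on the standard Mills-ratio approximation, not something taken from the cited sources.

For small \(p\), the cut point \(z = \Phi^{-1}(1-p)\) is large. In that regime \(p = 1 - \Phi(z) \approx \phi(z)/z\) and \(\sqrt{p(1-p)} \approx \sqrt{p}\), so

\[ e(p) \approx \frac{\phi(z)}{\sqrt{\phi(z)/z}} = \sqrt{z\,\phi(z)} = \sqrt{\frac{z\,e^{-z^2/2}}{\sqrt{2\pi}}} \;\longrightarrow\; 0 \quad (z \to \infty). \]

By symmetry the same holds as \(p \to 1\), and correspondingly \(1/e(p) \to \infty\).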


import numpy as np
from scipy.stats import norm
c = 0.000001
p = np.linspace(0 + c, 1 - c, 500)

def attenuation_effect(p: np.ndarray) -> np.ndarray:
    # z = Phi^{-1}(p); phi is symmetric, so this gives the same e(p) as using Phi^{-1}(1 - p)
    z = norm.ppf(p)
    e = norm.pdf(z) / np.sqrt(p * (1 - p))
    return e

e = attenuation_effect(p)

import matplotlib.pyplot as plt
# fig, ax = plt.subplots(figsize=[4, 2])
fig, axes = plt.subplots(figsize=[9, 2], ncols=2)
ax = axes[0]
ax.plot(p, e)
ax.set(
    title=r"Attenuation effect $e(p) = \frac{\phi(z)}{\sqrt{p(1-p)}}$",
    ylabel=r"$e(p) = \frac{\phi(z)}{\sqrt{p(1-p)}}$",
    xlabel=r"proportion of population above split point $p$",
    ylim=[0,1]
)

ax = axes[1]
ax.plot(p, 1/e)
ax.set(
    title=r"Inverse of Attenuation effect $e(p)$",
    ylabel=r"$1 / e(p)$",
    xlabel=r"proportion of population above split point $p$",
)

fig.show()
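(Figure. Left: "Attenuation effect e(p)" plotted against the proportion of the population above the split point p. Right: "Inverse of Attenuation effect e(p)", 1/e(p), plotted against the same axis.)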

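Consequently, when \(p\) is very close to 0 or 1, the correction factor can become so large that the absolute value of the biserial correlation exceeds \(1\) in a finite sample. This is especially likely when the true correlation is close to \(\pm 1\).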

import numpy as np
from scipy.stats import pearsonr

rho = -0.999
n = 100
np.random.seed(0)
data = np.random.multivariate_normal(mean=[0, 0], cov=[[1, rho], [rho, 1]], size=n)
x, y = data[:, 0], data[:, 1]

threshold = 0.999
yd = 1 * (y >= threshold)
p = yd.mean()
e = attenuation_effect(p)

r_pbi = pearsonr(x, yd).statistic
r_bi = r_pbi / e
print(f"{rho=:.3f}, {p=:.3f}, {r_pbi=:.3f}, {r_bi=:.3f}")
rho=-0.999, p=0.170, r_pbi=-0.714, r_bi=-1.060

References