Pythonでジニ係数を計算してみる
データ分析で予測分類モデルとか作るときに、目的変数に対するカテゴリカルな特徴量の寄与をただ知りたいときってありますよね。
Rを使ってたときは、パッケージでジニ係数とか情報利得とかを計算する関数があったんですが、Pythonでは決定木とかに組み込まれているものしか見つけられなかったので作ってみました。
以下の流れでご紹介
ジニ係数って
wikiによれば
ジニ係数(ジニけいすう、Gini coefficient)とは、主に社会における所得分配の不平等さを測る指標。ローレンツ曲線をもとに、1936年、イタリアの統計学者コッラド・ジニによって考案された。所得分配の不平等さ以外にも、富の偏在性やエネルギー消費における不平等さなどに応用される。
とのことですが、データ分析の文脈では分布の不均衡さを表す指標として使います。
Pythonで実装
まずは単一の状態でのジニ係数を計算する関数を実装します。
こいつをある変数に対して分割前の状態と分割後の各状態で計算して、その差分を見るのがデータ分析でのジニ係数の扱いですね。
# -*- encoding=utf-8 -*- import numpy import pandas from collections import Counter def gini(vec): prob2s = list() count = Counter(vec) countall = float(numpy.sum(count.values())) for item in count.items(): counteach = item[1] prob = counteach/countall prob2 = prob**2 prob2s.append(prob2) gini = 1 - numpy.sum(prob2s) print "%s : %f" % (vec.name, gini) return gini
続いて、このginiを呼び出して分割前の目的変数yのジニ係数と、あるカテゴリカル変数xで分割した後のジニ係数を比較します。この差分が正の方向に大きいほど、「xによる分割で各サブセット内のyが均質化した」ことになるので、xは影響力のありそうな特徴量かな、となります。
def giniIndex(x,y): # x,yともにカテゴリカル変数 root_gini = gini(y) grouped = y.groupby(x) leaf_gini_0 = grouped.apply(gini) leaf_weight = grouped.apply(len) / float(len(y)) leaf_gini = leaf_gini_0 * leaf_weight # 0に近いほど均質 = 正例・負例の純度が高いカテゴリが存在 giniindex = root_gini - leaf_gini.sum() print "gini index : %f" %giniindex return giniindex
ついでにxが連続量の場合についても作ってみました。
厳密にはxについて積分が必要な気がしますが、簡易版として100分割して
離散化しています。
def giniIndexNum(x,y, bins=100): # xが連続変数、yがカテゴリカル変数 root_gini = gini(y) xname = x.name yname = y.name xy = pandas.concat([x,y], axis=1) xbins = pandas.cut(xy[xname], bins) grouped = xy[yname].groupby(xbins) leaf_gini_0 = grouped.apply(gini) leaf_weight = grouped.apply(len) / float(len(y)) leaf_gini = leaf_gini_0 * leaf_weight giniindex = root_gini - leaf_gini.sum() print "gini index : %f" %giniindex return giniindex
では実際に使ってみます。*1
x = numpy.random.choice(["A","B","C"], 100) x = pandas.Series(x) y = numpy.random.choice([0,1], 100) y = pandas.Series(y) y.name = "testy" giniIndex(x,y)
これを実行すると、以下のようになります。
testy : 0.499800 A : 0.482422 B : 0.495317 C : 0.499635 gini index : 0.007012
せば
*1:12/26 コードに一部誤りがあったので修正しました。m(_ _)m