XGBoostのRNGをMTに置換える

背景

オプトDSLで開催された「ユーザー離脱予想」のコンペで入賞しました。 結構丁寧に検収をして頂くのですが、オプトの方とこちらとでどうしても結果が一致せずに困り果てていました。

Twitterでつぶやいた所、有益情報をゲットします。

random.hを調べてみたところ…

struct Random{
  inline void Seed(unsigned sd) {
    this->rseed = sd;
    #if defined(_MSC_VER)||defined(_WIN32)
    ::xgboost::random::Seed(sd);
    #endif
}

確かにWinとそれ以外で挙動が違っています。

また、見落としていたのいたのですが、公式に以下の投稿を発見しました。 github.com

そもそもrand()処理系依存なので、異なる環境で結果の一致は望むべくもないとわかりました。

モチベーション

当初の課題は解決したのですが、今度は評判の悪いrand()を使っていて良いのかという点が気になってきました。

何も考えずにMersenne Twisterを使うように教育されてきた人間なので、質の悪い乱数発生器で機械学習するというのには抵抗があります。 MCMC*1やRandom Forestでrand()を使っているものがあればゴミ扱いも免れないでしょう。 しかしながら、GBRTでは新しく学習されるツリーは前回までの学習結果に強烈に依存しているので、乱数の質が問題となりにくいようにも思われます。

とにかく、rand()を使っているせいでKaggleの順位が伸びないのであれば癪なので実験してみました。

XGBoostの改造

乱数を扱っているのは前述のrandom.hですので、ここに手を入れます。

namespace xgboost {
namespace random {
  extern std::mt19937 mt;
  extern bool use_mt;
  inline void Seed(unsigned seed) {
    srand(seed);
    mt.seed(seed);
  }
  inline void set_use_mt(unsigned use) {
    if (use > 0) {
      use_mt = true;
    } else {
      use_mt = false;
    }
  }
  inline double Uniform(void) {
    if (use_mt) {
      return static_cast<double>(mt()) / (static_cast<double>(mt.max())+1.0);
    } else {
      return static_cast<double>(rand()) / (static_cast<double>(RAND_MAX)+1.0);
    }
  }
  inline double NextDouble2(void) {
    if (use_mt) {
      return (static_cast<double>(mt())+1.0) / (static_cast<double>(mt.max())+2.0);
    } else {
      return (static_cast<double>(rand())+1.0) / (static_cast<double>(RAND_MAX)+2.0);
    }
  }
}
};

乱数発生器を切替えるために、どこかでset_use_mt()を実行しなければなりません。 seedを設定している箇所がlearner-inl.hppSetParam()にあるので、そこに処理を追加します。

if (!strcmp("seed", name)) {
  seed = atoi(val); random::Seed(seed);
}
if (!strcmp("use_mt", name)) {
  random::set_use_mt(atoi(val));
}

並列化については何も考えずに設計しているので、OpenMPはOFFにしてビルドします。

数値実験

seedを揃えてrand()版とmt()版でそれぞれ学習をします。 seedを変えてこれを繰り返し精度を比較してみます。

データはKaggleのOttoを用いることにします。

実験用のコードは以下の通り。

prm = {'colsample_bytree': 0.5,
       'eval_metric': 'mlogloss',
       'max_depth': 7,
       'num_class': 9,
       'objective': 'multi:softprob',
       'silent': 0,
       'subsample': 0.9}
num_round = 30
def fit_XGB(y, x, flg):
    loss = []
    flg  = np.in1d(x.index, idx)
    mat  = xgb.DMatrix(x[flg].values, label=y[flg].values)
    clf  = xgb.train(prm, mat, num_round)
    mat  = xgb.DMatrix(x.values, label=y.values)
    pred = clf.predict(mat)
    loss.append(cf.log_loss_multi(y[flg].values, pred[flg]))
    flg = np.logical_not(flg)
    loss.append(cf.log_loss_multi(y[flg].values, pred[flg]))
    return loss

loss0 = []
loss1 = []
num_iter = 300
for ii in range(num_iter):
    idx   = np.random.choice(x.index, int(0.5*len(x)), replace=False)
    flg   = np.in1d(x.index, idx)
    sd    = int(np.random.random_sample() * 1000000000)
    prm['seed']   = sd
    prm['use_mt'] = 0
    loss0.append(fit_XGB(y, x, flg))
    prm['use_mt'] = 1
    loss1.append(fit_XGB(y, x, flg))
loss_rd = pd.DataFrame(loss0)
loss_rd.columns = ['IN', 'OUT']
loss_mt = pd.DataFrame(loss1)
loss_mt.columns = ['IN', 'OUT']

実行結果は以下のようになりました。

rand()

項目 IN OUT
mean 0.316084 0.545039
std 0.003770 0.003840
min 0.306325 0.535338
25% 0.313512 0.542393
50% 0.316218 0.545169
75% 0.318698 0.547285
max 0.326892 0.556248

mt()

項目 IN OUT
mean 0.316162 0.545059
std 0.003549 0.003830
min 0.306115 0.535763
25% 0.313710 0.542551
50% 0.315968 0.544973
75% 0.318608 0.547787
max 0.324961 0.555478

今回はrand()版の方が良い精度となりました。

結論

試行回数300は少ない、etaパラメータが大きい等に調整の余地はありますが、乱数の質に拘っても精度への貢献は僅かだと予想されます。

本家がMersenne Twisterに替えてくれるならそれで良し、そうでなくても安心して使って良さそうです。

*1:質の悪い乱数でもMarkov Chain側で調整が効くので何とかなるような気もしてきましたが、どうなんでしょう。

Rからパラメータ付きCypherクエリを投げる

KaggleやCrowdSolvingでレコメンのコンペが開催されたときに使いたいなぁと思ってNeo4jの勉強を始めたのですが、グラフDBに適した問題がなかなか出てきません。 今回はNeo4j 2.0がリリースされた記念に記事を書いてみました。

目標

RからCypherクエリを投げて結果をデータセットにします。 いついかなる時もCypherクエリはパラメータ化すべきとのことなので、それにも従います。

使用データ

KaggleのEvent Recommendation Engine Challengeデータを使っています。 BatchInserterを使ってDBに挿入したのですが、重複レコードがたくさんあって大変でした。

CSVファイルの容量は1.5GBくらいですが、Neo4jに放り込むと5.5GBくらいになりました。

参考

Stack Overflowにあったコードをベースにしています。 こういうのも見つけましたが、ざっと眺めた感じパラメータ化はされていないようです。

コード

library(RCurl)
library(RJSONIO)
getQuery <- function (query, params) {
  h  =  basicTextGatherer()
  pf <- toJSON(list(query=query, params=params))
  curlPerform(url="localhost:7474/db/data/cypher",
              httpheader=c("Content-Type"="application/json"),
              customrequest="POST",
              postfields=pf,
              writefunction=h$update)
  result <- fromJSON(h$value())
  if (is.element("exception", names(result))) {
    dat <- result
  } else {
    dat <- data.frame(t(sapply(result$data, function(y) y)))
    if (ncol(dat)==length(result$columns)) {
      names(dat) <- result$columns
    }
  }
  return(dat)
}

参考にしたコードと異なり、クエリとパラメータをJSON形式で渡すので、httpheaderでそれを設定しています。

実行

あるイベントについて、指定したユーザの友人の内から何名がそれに参加しているかを問い合わせます。

query <-
"START u = node:users(user_id = {uid})
MATCH u-[:FOLLOW]->f-[:ATTEND]->e
RETURN u.user_id, e.event_id, count(e) LIMIT 10"
dat      <- NULL
params   <- list(uid="3197468391")
dat[[1]] <- getQuery(query, params))
params   <- list(uid="3429017717")
dat[[2]] <- getQuery(query, params))

実行すると、

[[1]]
    u.user_id e.event_id count(e)
1  3197468391  169644382        1
2  3197468391  543972501        1
3  3197468391 3969940212        1
4  3197468391 2977769484        1
5  3197468391 1704179171        1
6  3197468391  608092517        3
7  3197468391  730958187        1
8  3197468391   19341280        1
9  3197468391  445373500        1
10 3197468391 2539029764      284

[[2]]
    u.user_id e.event_id count(e)
1  3429017717 3183605169        2
2  3429017717 2180806657       46
3  3429017717 2039358442        2
4  3429017717 2412032092        2
5  3429017717 3541811987        1
6  3429017717 2368083210        1
7  3429017717 1506378274        1
8  3429017717 1177314523        1
9  3429017717 3163090701       12
10 3429017717  266513530        1

ちゃんと結果が返ってきました。僅かですがベタ書きより良いパフォーマンスも確認できました。めでたしめでたし。