RandomForestClassifierの中を見てみる


概要


前回sklearnのDecisionTreeClassifierのソースを見てみました。しかし現実にはDecisionTree単体で用いられることはなく、それらを集めたensembleとして使われることが一般的です。そこで、そんなensembleのうちで最も単純なRandomForestClassiferのソースを見てみたいと思います。



fitメソッド


まずはfit()メソッドを見てみます。RandomForestClassifierは継承元のさらに継承元のBaseForestクラスのfitメソッドを使います。


/sklearn/ensemble/_forest.py#L176

class BaseForest(MultiOutputMixin, BaseEnsemble, metaclass=ABCMeta):
def fit(self, X, y, sample_weight=None):

//略 パラメータチェックなど

if not self.warm_start or not hasattr(self, "estimators_"):
    # Free allocated memory, if any
    self.estimators_= []
    
 //1 作るestimatorの数
 n_more_estimators=self.n_estimators-len(self.estimators_)
 
 if n_more_estimators<0:
     raiseValueError('n_estimators=%d must be larger or equal to  
                     ''len(estimators_)=%d when warm_start==True'
                     %(self.n_estimators, len(self.estimators_)))
                     
 elif n_more_estimators==0:
     warn("Warm-start fitting without increasing n_estimators 
           does not ""fit new trees.")
           
 else:
     if self.warm_start and len(self.estimators_) >0:
     # We draw from the random state to get the random state we
     # would have got if we hadn't used a warm_start.
        random_state.randint(MAX_INT, size=len(self.estimators_))
    
    //2 estimatorを作成   
    trees=  [self._make_estimator(append=False,
             random_state=random_state)
             for i in range(n_more_estimators)]
    
    //3 それぞれ学習         
    trees=Parallel(n_jobs=self.n_jobs, 
            verbose=self.verbose,           
        **_joblib_parallel_args(prefer='threads'))      
   (delayed(_parallel_build_trees)
        (t, self, X, y, sample_weight, i, len(trees),
         verbose=self.verbose,    
         class_weight=self.class_weight,
         n_samples_bootstrap=n_samples_bootstrap)
   for i, t in enumerate(trees))
   
  # Collect newly grown trees
  self.estimators_.extend(trees)

まずパラメータのチェックなどをした後、作成するestimator(RandomForestの場合はDecisionTreeClassifier)の数を決めます。warm_startがTrueの場合は前回作成したものを再利用するため、足りない分だけ作ります(デフォルトではFalseです)。

2の部分で実際に指定の個数だけestimatorを作成しています。


そして3で各々を学習させます。_parallel_build_treesがここの木を学習させるメソッドです。

def _parallel_build_trees(tree, forest, X, y, sample_weight, 
                          tree_idx, n_trees,verbose=0,                               
                          class_weight=None,
                          n_samples_bootstrap=None):
    """    
    Private function used to fit a single tree in parallel."""
    
    if verbose>1: 
        print("building tree %d of %d"% (tree_idx+1, n_trees))
     
    //4 学習させるためのサンプルを作成   
    if forest.bootstrap:
       n_samples=X.shape[0]
       if sample_weight is None:
           curr_sample_weight=np.ones((n_samples,), 
           dtype=np.float64)     
       else:    
           curr_sample_weight=sample_weight.copy()
           
       indices=_generate_sample_indices(tree.random_state, 
               n_samples,n_samples_bootstrap)
           
       sample_counts=np.bincount(indices, minlength=n_samples)       
       curr_sample_weight*=sample_counts
       
       if class_weight=='subsample':
           with catch_warnings():
               simplefilter('ignore', DeprecationWarning)
               curr_sample_weight*=compute_sample_weight('auto',  
               y,indices=indices)
           
       elif class_weight=='balanced_subsample':
           curr_sample_weight*=compute_sample_weight('balanced', 
           y,indices=indices)
        
       //5 学習   
       tree.fit(X, y, sample_weight=curr_sample_weight, 
                check_input=False)
   else:
       tree.fit(X, y, sample_weight=sample_weight, 
               check_input=False)
       
   return tree

bootstrapパラメータがTrueの場合は(デフォルトではTrueです)与えられたサンプルをそのまま学習に使うのではなく、木毎に異なるサンプルセットを渡します。(過学習対策?)

参考

サンプルセットを作り直すのではなく、サンプル毎に重みをつけます。

if文内最初でこの計算を行った後、5で学習をします。


以上でfitメソッドは終了です。


predictメソッド


次はpredicrtメソッドを見てみます。predictメソッドについてはRandomForestClassifierの継承元のForestClassifierのものを用います。


/sklearn/ensemble/_forest.py#L176

class ForestClassifier(ClassifierMixin, BaseForest, 
                       metaclass=ABCMeta):
def predict_proba(self, X):

    all_proba= [np.zeros((X.shape[0], j), dtype=np.float64)
                for j in np.atleast_1d(self.n_classes_)]
    lock=threading.Lock()

  //6 木毎に予測
    Parallel(n_jobs=n_jobs, verbose=self.verbose,
             **_joblib_parallel_args(require="sharedmem"))
       (delayed(_accumulate_prediction)(
             e.predict_proba, X, all_proba,lock) 
             for e in self.estimators_)
    
    for proba in all_proba:
        proba/=len(self.estimators_)
        
    if len(all_proba) ==1:
        return all_proba[0]
    else:
        returnall_proba

predictを見る前にpredict内で用いられるpredict_probaメソッドを見てみます。これは各データが各カテゴリに分類される確率を計算するものです。

all_probaはデータ数×カテゴリ数の配列で、各データが各カテゴリに分類される確率(の和)を格納します。

6で木毎に予測を行います。_accumulate_predictionでは各カテゴリに分類される確率を求め、all_probaに足していきます。(コードは下記)

最後に木の数で割って規格化します。


def _accumulate_prediction(predict, X, out, lock):

    prediction=predict(X, check_input=False)
    withlock:
        if len(out) ==1:
            out[0] +=prediction
        else:for i in range(len(out)):
            out[i] +=prediction[i]


では本題のpredictメソッドを見てみましょう。


def predict(self, X):

    proba=self.predict_proba(X)
    
    if self.n_outputs_==1:
        return self.classes_.take(np.argmax(proba, axis=1), 
                                  axis=0)        
    else:
       //略

ここではpredict_probaメソッドで各カテゴリに分類される確率を求め、それが最大となる物を予測値として返します。

より多くの木が高い確率を出すとprobaの値は大きくなるので、木による多数決になっていますね。

n_outputsが2以上の場合(つまりターゲットが2次元以上の場合)はとりあえず省略します。



参考文献


sklearn/ensemble/_forest.py

sklearn/ensemble/_base.py

【機械学習】ランダムフォレストについて メモ

https://funatsu-lab.github.io/open-course-ware/machine-learning/random-forest/



最後に


ensemble学習はあくまで個々の学習機をまとめるだけなのでそこまでコードは複雑ではなかったです。

例によってコードは結構省略しているので是非ご自身でご確認ください。

最新記事

すべて表示

概要 フィッティングを行いたい場合、pythonならばscipy.optimize.leastsqなどでできます。 しかし、フィッティングを行う場合、フィッティングパラメータに条件を付けたい場合も多々あります。 例えば、下記のようにパラメータa、bは共に正の範囲で最適な値を求める、という感じです。 f(x, a, b)=a*x^2+b (a>0 and b>0) 今回はそんな手法についてご紹介しま

靴を大切にしよう!靴管理アプリ SHOES_KEEP

納品:iPhone6.5①.png

靴の履いた回数、お手入れ回数を管理するアプリです。

google-play-badge.png
Download_on_the_App_Store_Badge_JP_RGB_blk_100317.png

テーマ日記:テーマを決めてジャンルごとに記録

訂正①2040×1152.jpg

ジャンルごとにテーマ、サブテーマをつけて投稿、記録できる日記アプリです。

google-play-badge.png