极端随机森林(Extremely Random Forest,ERF)
定义一个类来处理ERF训练。
这里利用标签编码器(LabelEncoder)把训练标签编码为数字,再用编码后的标签训练分类器。
class ERFTrainer(object):
    """Train and query an Extremely Random Forest (ERF) classifier.

    Word labels are mapped to numeric codes with a ``LabelEncoder`` and an
    ``ExtraTreesClassifier`` is fitted on those codes. The fitted encoder is
    kept so predictions can be mapped back to the original labels.
    """

    def __init__(self, X, label_words, n_estimators=100, max_depth=16,
                 random_state=0):
        """Fit the label encoder and the forest on the training data.

        Args:
            X: Sequence of feature vectors; converted with ``np.asarray``
                before fitting.
            label_words: Sequence of class labels, one per entry of ``X``.
            n_estimators: Passed to ``ExtraTreesClassifier``.
            max_depth: Passed to ``ExtraTreesClassifier``.
            random_state: Passed to ``ExtraTreesClassifier``.
        """
        # self.le: encoder mapping label words <-> numeric codes.
        self.le = preprocessing.LabelEncoder()
        # self.clf: the forest trained on the encoded labels.
        self.clf = ExtraTreesClassifier(
            n_estimators=n_estimators,
            max_depth=max_depth,
            random_state=random_state,
        )
        y = self.encode_labels(label_words)
        self.clf.fit(np.asarray(X), y)

    def encode_labels(self, label_words):
        """Fit the encoder on ``label_words`` and return their numeric codes.

        Args:
            label_words: Sequence of class labels.

        Returns:
            A ``float32`` NumPy array holding one code per input label.
        """
        codes = self.le.fit_transform(label_words)
        return np.array(codes, dtype=np.float32)

    def classify(self, X):
        """Predict the label word for each feature vector in ``X``.

        Args:
            X: Sequence of feature vectors; converted with ``np.asarray``.

        Returns:
            The labels produced by the encoder's ``inverse_transform``.
        """
        label_nums = self.clf.predict(np.asarray(X))
        # The forest was trained on float32 codes, so cast the predictions
        # back to integers before asking the encoder to decode them.
        label_codes = np.asarray(label_nums).astype(int)
        return self.le.inverse_transform(label_codes)
主函数
# -*- coding:utf8 -*-
import argparse
import pickle as pickle
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier
from sklearn import preprocessing
- 调试方式:在 PyCharm 中直接在代码里写入输入、输出文件路径,方便调试(与下面的命令行方式二选一)
if __name__ == '__main__':
    # Debug-friendly variant: the input and output paths are hard-coded here
    # instead of being read from the command line.
    feature_map_file = 'feature_map/9_8.pkl'
    model_file = 'model/train_2.pkl'

    # Load the feature map produced in section 9.8.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(feature_map_file, 'rb') as f:
        feature_map = pickle.load(f)

    # feature_map[0] is read below, so an empty map would otherwise fail with
    # a bare IndexError; report the actual problem instead.
    if len(feature_map) == 0:
        raise ValueError(
            'Feature map {!r} is empty; nothing to train on'.format(
                feature_map_file))

    # Extract the labels and flatten every feature vector to 1-D. The length
    # is taken from the second axis of the first entry's feature vector.
    label_words = [x['object_class'] for x in feature_map]
    dim_size = feature_map[0]['feature_vector'].shape[1]
    X = [np.reshape(x['feature_vector'], (dim_size,)) for x in feature_map]

    # Train the ERF classifier and save the trained model.
    erf = ERFTrainer(X, label_words)
    if model_file:
        with open(model_file, 'wb') as f:
            pickle.dump(erf, f)
- 命令行方式:通过命令行参数指定输入、输出文件路径(与上面的 PyCharm 调试方式二选一)
def build_arg_parser():
    """Build the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser`` with a required ``--feature-map-file``
        option and an optional ``--model-file`` option.
    """
    cli = argparse.ArgumentParser(description='Trains the classifier')
    cli.add_argument(
        "--feature-map-file",
        dest="feature_map_file",
        required=True,
        help="Input pickle file containing the feature map",
    )
    cli.add_argument(
        "--model-file",
        dest="model_file",
        required=False,
        help="Output file where the trained model will be stored",
    )
    return cli
if __name__ == '__main__':
    # Command-line variant: the input and output paths come from the parsed
    # arguments.
    args = build_arg_parser().parse_args()
    feature_map_file = args.feature_map_file
    model_file = args.model_file

    # Load the feature map.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(feature_map_file, 'rb') as f:
        feature_map = pickle.load(f)

    # feature_map[0] is read below, so an empty map would otherwise fail with
    # a bare IndexError; report the actual problem instead.
    if len(feature_map) == 0:
        raise ValueError(
            'Feature map {!r} is empty; nothing to train on'.format(
                feature_map_file))

    # Extract the labels and flatten every feature vector to 1-D. The length
    # is taken from the second axis of the first entry's feature vector.
    label_words = [x['object_class'] for x in feature_map]
    dim_size = feature_map[0]['feature_vector'].shape[1]
    X = [np.reshape(x['feature_vector'], (dim_size,)) for x in feature_map]

    # Train the Extremely Random Forests classifier.
    erf = ERFTrainer(X, label_words)

    # --model-file is optional; the model is saved only when it was given.
    if model_file:
        with open(model_file, 'wb') as f:
            pickle.dump(erf, f)