基于上次的的初阶决策树,我们这次进行一定程度上的进阶(上一次的文章//www.greatytc.com/p/c2fbed43c49b)
首先这次我们使用更加方便的Pipline的形式来实现决策树,另外再尝试用matplotlib画出一个分类图。
数据同样是使用鸢尾花的数据,为了方便画图,这次我们只挑选两个特征,进行决策树拟合。
1、导入相关库
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.datasets import load_iris
2、导入数据
x = load_iris().data
y = load_iris().target
x = x[:, :2]
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)
这里的自变量只选取前两个特征
3、使用管道创建决策树对象,并进行拟合
model = Pipeline([
('ss', StandardScaler()),
('DTC', DecisionTreeClassifier(criterion='entropy', max_depth=3))])
model = model.fit(x_train, y_train)
y_test_hat = model.predict(x_test) # 测试数据
采用管道的流水线操作更加的便利.
管道类创建对象时可以传入一个具有两个元组的列表,其中第一个元组是数据预处理方式,第二个是训练模型,名字可以自己自由定制.
另外此次选取的两个特征值生成的模型预测准确率为80%。
4、保存决策树结果,并生成树图
export_graphviz(model.get_params('DTC')['DTC'],out_file='tree3.dot')
5、生成可视化二维图像
# 首先确定数值范围
N, M = 100, 100 # 横纵各采样多少个值
x1_min, x1_max = x[:, 0].min(), x[:, 0].max() #确定第0列的范围
x2_min, x2_max = x[:, 1].min(), x[:, 1].max() # 确定第1列的范围
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点
x_show = np.stack((x1.flat, x2.flat), axis=1) # 测试点
iris_feature = ['花萼长度', '花萼宽度', '花瓣长度', '花瓣宽度']
# 画图
cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
y_show_hat = model.predict(x_show) # 预测值
y_show_hat = y_show_hat.reshape(x1.shape) # 使之与输入的形状相同
plt.figure(facecolor='w')
plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light) # 预测值的显示
plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test.ravel(), edgecolors='k', s=100, cmap=cm_dark, marker='o') # 测试数据
plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark) # 全部数据
plt.xlabel(iris_feature[0], fontsize=15)
plt.ylabel(iris_feature[1], fontsize=15)
plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
plt.grid(True)
plt.title('鸢尾花数据的决策树分类', fontsize=17)
plt.show()
由上图可知,绿色部分分类相对较为理想,蓝色和红色区域有少些数据预测并不准确,最直观的表现就行红色区域有些许蓝点,蓝色区域有些许红点。
另外,此篇文字的代码主要来自小象学院的邹博老师。