数据小站
数据科学成长之路

sklearn中DecisionTree决策树模型使用-初体验(一)

sklearn决策树模块中,DecisionTreeClassifier分类器的基本使用体验。

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
from matplotlib import pyplot as plt

if __name__ == '__main__':
    #加载经典的鸢尾花数据集
    iris = datasets.load_iris()
    data = iris.data
    target = iris.target
    feature_names = iris.feature_names
    target_names = iris.target_names

    #为了方便可视化,取了2个维度的特征,实际应用中特征维度会很多
    X = data[:][:,:2]

    #创建决策树分类器对象,并训练,预测三步走
    clf = DecisionTreeClassifier()
    clf.fit(X,target)
    y = clf.predict(X)

    #可视化,查看预测和真实结果的差异
    colors =['g','r','b']
    fig,axes = plt.subplots(2,1,sharex=True,sharey=True)
    for i in range(3):
        x1 = X[target==i]
        axes[0].scatter(x1[:,0],x1[:,1],s=15 ,c=colors[i])
    for i in range(3):
        x1 = X[y==i]
        axes[1].scatter(x1[:,0],x1[:,1],s=15 ,c=colors[i])
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.xlabel(feature_names[0])
    plt.ylabel(feature_names[1])
    plt.show()

skearn中决策树分类器,DecisionTreeClassifier的应用,和大部分分类器步骤一样,三个步骤,①实例化一个分类器对象 clf = DecisionTreeClassifier(),②训练分类器 clf.fit(X,target),③应用预测 y = clf.predict(X)。

本次案例,采用sklearn自带的经典数据集–鸢尾花数据集,为了方便可视化查看效果,用了2个特征训练决策树。

sklearn中DecisionTreeClassifier对训练数据的格式要求:clf.fit(X,target)

官方文档中对输入数据 ”X“ 的格式要求:

X : {array-like, sparse matrix} of shape (n_samples, n_features)
The training input samples. Internally, it will be converted to
``dtype=np.float32`` and if a sparse matrix is provided
to a sparse ``csc_matrix``.

决策树算法本身,是可以直接对数值型、标称型数据做训练。

但是在sklearn实现的决策树分类器上,对输入数据有要求:输入的训练集,数据类型为浮点型数据。如果数据集中有类别特征,需要在特征工程阶段,将文本类别变量,转化成数值型,常用one-hot编码。

决策树的预测结果

本次案例,直接将鸢尾花所有的数据当做数据集,输入给决策树模型,用于训练。

模型训练好后,又直接将数据集,喂给模型,用于预测。

图一中,为二维特征下的原始特征分布, 三种颜色分别对应三种类别。

图二中,为通过决策树模型预测后,二维特征对应的预测结果集分类。

红框中,圈出了部分预测的结果和实际的结果不一致的点。

在sklearn中调用决策树模型,本身很便利,核心的调取模块只有三行,实例分离器,训练,预测:

clf = DecisionTreeClassifier()

clf.fit(X,target)

y = clf.predict(X)

决策树边界可视化

决策树的2D可视化,决策边界绘制,参考pcolormesh() 区域编辑绘制使用方法

决策树的保存

tree.export_graphviz(clf,out_file='clf')

tree模块下export_graphviz() 可以将决策树保存到本地,可以通过本地txt文件打开查看。 直接生成的文件可读性很差,可以借助graphviz 工具生成树的结构。

赞(0) 打赏
未经允许不得转载:技术文档分享 » sklearn中DecisionTree决策树模型使用-初体验(一)

评论 抢沙发