Decision Tree In MNIST
决策树作为机器学习算法的基础算法之一,其优点是模型具有可读性,分类速度快。决策树的基本算法可以参考李航老师的《统计机器学习》,本文中的主要思路也是基于此。对于决策树而言,主要有三部分的工作需要实现,特征提取,决策树构造和决策树的剪枝。在特征提取部分,常用的是信息增益和信息增益比,分别对应于ID3和C4.5决策树算法;决策树的构造部分主要是利用递归的思想构建;本文中未涉及到决策树剪枝部分。
本文中实现的是利用信息增益进行特征提取的ID3决策树,主要需要提及的部分是关于树的节点构造的部分。树的节点中主要包含了节点类型,包括叶节点和内部节点,用于在预测过程中当预测到达叶节点时,返回所属类别;也包含了节点所属类别(只针对叶节点存在);包含当前节点的最优切分位置(对于非叶节点存在);以及切分位置取不同值所属的不同子树,用字典保存。
基本思路是:我们在函数中输入训练数据,训练标签,可切分维度集合和信息增益阈值。
首先判断当前训练标签的类数,如果只有一类,我们将其设置为叶节点,节点类别取当前标签,然后返回;然后判断可切分维度的数量,如果不存在切分维度,则取当前训练标签中的最多那一类作为节点的类别,设置节点为叶节点返回;
以上情况都不符合的情况下,我们遍历所有的切分点,找到信息增益最大的位置,提取出这一列的值;并根据这一列的不同取值所对应的标签得到下一步的训练数据、训练标签和切分维度集合用于构建子树,并将子树添加到当前树中。
- 然后不断的递归调用即可
- 其中在找到最大信息增益位置的最后需要判断和阈值的关系,在小于阈值的情况下,我们直接构建叶节点和所属类别返回。
代码如下:
|
|
对于本文中所实现的决策树在MNIST上的分类结果在Kaggle平台上测试结果为85.07%左右,相比于KNN,其分类效果要差不少,但是所需要的计算时间较少于KNN算法。