决策树算法核心原理
由bqw9z8tc创建,最终由small_q 被浏览 33 用户
决策树算法是一种流行的机器学习算法,用于分类和回归任务。它的核心思想是基于特征对数据集进行递归分割,从而构建出一个树形结构。每个内部节点代表一个特征的测试,每个分支代表测试的一个结果,每个叶节点代表一个类别(在分类树中)或一个具体数值(在回归树中)。
发展史
- CLS(Concept Learning System)堪称决策树算法的鼻祖。
- ID3算法,作为主流决策树算法的重要代表,凭借信息增益策略来精准选择特征并引导树的生长。
- C4.5算法在ID3的基础上更进一步,它运用信息增益率作为选择最优属性的准则,提升了决策树的性能。
- CART(分类与回归树)则是一种功能强大的二叉决策树,广泛应用于分类和回归任务,其高效性和准确性备受赞誉。
- RF(随机森林)是一种创新的组合算法,它将多个决策树并行集成,通过集体智慧来提高模型的稳定性和泛化能力。
- 而GBDT、XGBoost、LightGBM以及catboost等,作为boosting系列的杰出代表,它们能够基于树模型进行串行集成,不断提升模型的性能,是机器学习领域中的佼佼者。
构造过程
-
选择最佳分割属性:
- 决策树通过选取数据集中的最佳属性进行分割。这一步通常依赖于某种准则,如信息增益(用于ID3算法)、增益率(用于C4.5算法)或基尼不纯度(用于CART算法)。
-
分割数据:
- 根据选定的属性,将数据集分割成子集,每个子集应该在该属性上具有不同的值。
-
递归构造子树:
- 对每个子集重复上述过程,直到满足停止条件(如达到预设的树深度、子集大小小于阈值、子集在目标变量上纯度足够高等)。
核心算法
- 信息增益(ID3算法):衡量分割前后信息不确定性的减少量。
- 计算公式为:
- 信息增益=信息熵(父节点)−∑(子节点信息熵×子节点权重)
- 增益率(C4.5算法):信息增益除以属性的固有信息。减少了对多值属性的偏好。
- 基尼不纯度(CART算法):衡量数据集的不纯度。
- 计算公式为: 基尼不纯度=1−∑(*pi^*2) ;其中 pi 是第 i 类的相对频率。
决策树模型
接下来,我将生成一个简单的数据集,并展示如何用决策树对其进行分类。同时,我将创建一个决策树的图像来直观展示其结构。
上图是一个简单的二分类问题的数据集。数据集包含两个特征(Feature 1和Feature 2)和两个类别(用红色和蓝色表示)。
基于这些数据,我们构建了一个决策树模型。
上图展示了构建的决策树。这个树是这样构建的:
- 选择分割特征和阈值:决策树算法选择最佳特征和一个阈值来分割数据。在这个例子中,树的第一个分割是基于Feature 1的某个值。
- 创建内部节点和分支:每个内部节点代表一个特征的测试,每个分支代表测试的一个结果。例如,第一个节点测试Feature 1是否小于或等于某个特定值。
- 递归分割:对每个分割产生的子集重复这个过程,直到满足停止条件。在这个例子中,树的最大深度被设置为3。
- 生成叶节点:一旦达到停止条件,就会创建一个叶节点。在分类问题中,叶节点代表最终的类别预测。
决策树代码
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, plot_tree
# 生成一个简单的数据集
X, y = make_classification(n_samples=100, n_features=2, n_redundant=0, n_informative=2,
n_clusters_per_class=1, class_sep=1.5, random_state=0)
# 创建决策树模型
tree_model = DecisionTreeClassifier(max_depth=3, random_state=0)
tree_model.fit(X, y)
# 绘制数据点
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm', edgecolor='k', s=50)
plt.title('Simple Classification Data Set')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
# 绘制决策树
plt.figure(figsize=(12, 8))
plot_tree(tree_model, filled=True, feature_names=['Feature 1', 'Feature 2'], class_names=['Class A', 'Class B'])
plt.title('Decision Tree')
plt.show()
通过这样的结构,决策树能够清晰直观地展示出数据分类的过程。
使用场景
- 分类任务:如电子邮件垃圾过滤、客户分类、医疗诊断等。
- 回归任务:当目标变量是连续值时,决策树可以用于回归分析,如房价预测、股票价格预测等。
- 数据挖掘:由于决策树易于理解,常用于数据挖掘中,以发现数据中的模式。
这就是决策树在数据挖掘和机器学习领域广泛应用的原因之一:它不仅效果良好,而且非常容易解释和理解。