AI量化知识树

决策树算法核心原理

由bqw9z8tc创建,最终由small_q 被浏览 33 用户

决策树算法是一种流行的机器学习算法,用于分类和回归任务。它的核心思想是基于特征对数据集进行递归分割,从而构建出一个树形结构。每个内部节点代表一个特征的测试,每个分支代表测试的一个结果,每个叶节点代表一个类别(在分类树中)或一个具体数值(在回归树中)。

发展史

  1. CLS(Concept Learning System)堪称决策树算法的鼻祖。
  2. ID3算法,作为主流决策树算法的重要代表,凭借信息增益策略来精准选择特征并引导树的生长。
  3. C4.5算法在ID3的基础上更进一步,它运用信息增益率作为选择最优属性的准则,提升了决策树的性能。
  4. CART(分类与回归树)则是一种功能强大的二叉决策树,广泛应用于分类和回归任务,其高效性和准确性备受赞誉。
  5. RF(随机森林)是一种创新的组合算法,它将多个决策树并行集成,通过集体智慧来提高模型的稳定性和泛化能力。
  6. 而GBDT、XGBoost、LightGBM以及catboost等,作为boosting系列的杰出代表,它们能够基于树模型进行串行集成,不断提升模型的性能,是机器学习领域中的佼佼者。

构造过程

  1. 选择最佳分割属性

    1. 决策树通过选取数据集中的最佳属性进行分割。这一步通常依赖于某种准则,如信息增益(用于ID3算法)、增益率(用于C4.5算法)或基尼不纯度(用于CART算法)。
  2. 分割数据

    1. 根据选定的属性,将数据集分割成子集,每个子集应该在该属性上具有不同的值。
  3. 递归构造子树

    1. 对每个子集重复上述过程,直到满足停止条件(如达到预设的树深度、子集大小小于阈值、子集在目标变量上纯度足够高等)。

核心算法

  • 信息增益(ID3算法):衡量分割前后信息不确定性的减少量。
    • 计算公式为:
    • 信息增益=信息熵(父节点)−∑(子节点信息熵×子节点权重)
  • 增益率(C4.5算法):信息增益除以属性的固有信息。减少了对多值属性的偏好。
  • 基尼不纯度(CART算法):衡量数据集的不纯度。
    • 计算公式为: 基尼不纯度=1−∑(*pi^*2​) ;其中 pi​ 是第 i 类的相对频率。

决策树模型

接下来,我将生成一个简单的数据集,并展示如何用决策树对其进行分类。同时,我将创建一个决策树的图像来直观展示其结构。

上图是一个简单的二分类问题的数据集。数据集包含两个特征(Feature 1和Feature 2)和两个类别(用红色和蓝色表示)。

基于这些数据,我们构建了一个决策树模型。

上图展示了构建的决策树。这个树是这样构建的:

  1. 选择分割特征和阈值:决策树算法选择最佳特征和一个阈值来分割数据。在这个例子中,树的第一个分割是基于Feature 1的某个值。
  2. 创建内部节点和分支:每个内部节点代表一个特征的测试,每个分支代表测试的一个结果。例如,第一个节点测试Feature 1是否小于或等于某个特定值。
  3. 递归分割:对每个分割产生的子集重复这个过程,直到满足停止条件。在这个例子中,树的最大深度被设置为3。
  4. 生成叶节点:一旦达到停止条件,就会创建一个叶节点。在分类问题中,叶节点代表最终的类别预测。

决策树代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, plot_tree

# 生成一个简单的数据集
X, y = make_classification(n_samples=100, n_features=2, n_redundant=0, n_informative=2,
                           n_clusters_per_class=1, class_sep=1.5, random_state=0)

# 创建决策树模型
tree_model = DecisionTreeClassifier(max_depth=3, random_state=0)
tree_model.fit(X, y)

# 绘制数据点
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm', edgecolor='k', s=50)
plt.title('Simple Classification Data Set')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

# 绘制决策树
plt.figure(figsize=(12, 8))
plot_tree(tree_model, filled=True, feature_names=['Feature 1', 'Feature 2'], class_names=['Class A', 'Class B'])
plt.title('Decision Tree')
plt.show()

通过这样的结构,决策树能够清晰直观地展示出数据分类的过程。

使用场景

  • 分类任务:如电子邮件垃圾过滤、客户分类、医疗诊断等。
  • 回归任务:当目标变量是连续值时,决策树可以用于回归分析,如房价预测、股票价格预测等。
  • 数据挖掘:由于决策树易于理解,常用于数据挖掘中,以发现数据中的模式。

这就是决策树在数据挖掘和机器学习领域广泛应用的原因之一:它不仅效果良好,而且非常容易解释和理解。

标签

决策树算法
{link}