决策树模型笔记
1. 基础知识
基本模型形式
决策树是一种常见的监督学习模型,既可以做分类,也可以做回归。它通过一系列“如果…那么…”的规则不断划分特征空间,最终在叶子节点给出预测结果。
对于分类任务,模型会根据样本特征逐层判断,例如:
- 如果花瓣长度小于某个阈值,进入左子树
- 否则进入右子树
最终到达某个叶子节点后,叶子节点中占比最高的类别就是预测类别。
核心目标
决策树的核心目标是:在每一次节点划分时,找到一个最优特征和最优切分点,让划分后的子节点尽可能“纯”。
分类任务中常见目标包括:
- 让同一类别样本尽量落到同一个叶子节点
- 降低节点的不确定性
- 提升整体分类准确率
损失函数
决策树通常不直接写成统一的全局损失函数最小化问题,而是在每个节点上贪心地选择最优划分标准。
常见划分指标有:
- 基尼指数(Gini Index)
- 信息熵(Entropy)
以基尼指数为例:
Gini(D) = 1 - Σ(p_k)^2其中p_k表示样本集合D中第k类样本所占比例。基尼指数越小,说明节点越纯。
参数求解
决策树的参数求解过程本质上是一个递归划分过程:
- 在当前节点中遍历候选特征
- 为每个特征尝试不同划分阈值
- 计算划分后的不纯度下降
- 选择收益最大的划分方式
- 递归生成左右子树,直到满足停止条件
常见停止条件包括:
- 达到最大树深度
- 节点样本数过少
- 节点已经足够纯
应用示例(Python实现)
本项目使用scikit-learn中的DecisionTreeClassifier实现一个经典的鸢尾花三分类任务:
fromsklearn.treeimportDecisionTreeClassifier model=DecisionTreeClassifier(max_depth=3,random_state=42)model.fit(x_train,y_train)y_pred=model.predict(x_test)注意要点
- 决策树容易过拟合,需要通过
max_depth、min_samples_split等参数控制复杂度 - 决策树对特征缩放不敏感,一般不强制要求标准化
- 树结构可解释性强,适合教学演示和规则分析
- 单棵树性能通常不如集成模型,但更容易理解
2. 代码实践
model.py
model.py负责定义决策树模型、训练模型和预测接口。这里统一封装了:
build_model():创建模型train_model():拟合训练数据predict():执行预测
fromsklearn.treeimportDecisionTreeClassifierdefbuild_model(criterion:str="gini",max_depth:int=3,random_state:int=42,)->DecisionTreeClassifier:"""创建决策树分类模型。"""returnDecisionTreeClassifier(criterion=criterion,max_depth=max_depth,random_state=random_state,)deftrain_model(x_train,y_train,criterion:str="gini",max_depth:int=3,random_state:int=42,)->DecisionTreeClassifier:"""训练决策树分类模型。"""model=build_model(criterion=criterion,max_depth=max_depth,random_state=random_state,)model.fit(x_train,y_train)returnmodeldefpredict(model:DecisionTreeClassifier,x_test):"""使用训练好的模型进行预测。"""returnmodel.predict(x_test)train.py
train.py负责训练流程,包括:
- 训练集和测试集划分
- 调用
train_model()完成训练
代码中使用了stratify=y,保证分类任务中训练集和测试集的类别分布更加稳定。
fromsklearn.model_selectionimporttrain_test_splitfrommodelimporttrain_modeldefsplit_data(x,y,test_size:float=0.2,random_state:int=42,):"""划分训练集和测试集。"""returntrain_test_split(x,y,test_size=test_size,random_state=random_state,stratify=y,)defrun_train(x,y,test_size:float=0.2,random_state:int=42,criterion:str="gini",max_depth:int=3,):"""完成数据划分和模型训练。"""x_train,x_test,y_train,y_test=split_data(x,y,test_size=test_size,random_state=random_state,)model=train_model(x_train,y_train,criterion=criterion,max_depth=max_depth,random_state=random_state,)returnmodel,x_train,x_test,y_train,y_testeval.py
eval.py负责评估模型效果,输出:
- 准确率
accuracy - 混淆矩阵
confusion_matrix - 分类报告
classification_report
这些指标能帮助我们同时观察总体表现和各类别的精确率、召回率、F1 值。
fromsklearn.metricsimportaccuracy_score,classification_report,confusion_matrixfrommodelimportpredictdefevaluate_model(model,x_test,y_test)->dict:"""评估决策树分类模型效果。"""y_pred=predict(model,x_test)return{"accuracy":accuracy_score(y_test,y_pred),"confusion_matrix":confusion_matrix(y_test,y_pred),"classification_report":classification_report(y_test,y_pred),}dataload.py
dataload.py从sklearn.datasets中加载鸢尾花数据集:
- 特征
x:4 个花萼/花瓣数值特征 - 标签
y:3 个类别标签 target_names:类别名称,用于可视化展示
importpandasaspdfromsklearn.datasetsimportload_irisdefload_data():"""加载 sklearn 自带的 iris 分类数据集。"""dataset=load_iris()x=pd.DataFrame(dataset.data,columns=dataset.feature_names)y=pd.Series(dataset.target,name="target")returnx,y,dataset.target_namesrun.py
run.py是项目入口,负责串联整个流程:
- 加载数据
- 训练模型
- 评估模型
- 保存可视化结果
可视化部分包含:
- 决策树结构图
- 混淆矩阵图
frompathlibimportPathimportmatplotlib matplotlib.use("Agg")importmatplotlib.pyplotaspltfromsklearn.metricsimportConfusionMatrixDisplayfromsklearn.treeimportplot_treefromdataloadimportload_datafromevalimportevaluate_modelfrommodelimportpredictfromtrainimportrun_traindefsave_plots(model,x_test,y_test,class_names)->list[Path]:"""保存决策树结构图和混淆矩阵图。"""current_dir=Path(__file__).resolve().parent output_dir=current_dir/"figure"output_dir.mkdir(exist_ok=True)tree_path=output_dir/"decision_tree_structure.png"cm_path=output_dir/"decision_tree_confusion_matrix.png"fig,ax=plt.subplots(figsize=(16,10))plot_tree(model,feature_names=list(x_test.columns),class_names=list(class_names),filled=True,rounded=True,ax=ax,)fig.tight_layout()fig.savefig(tree_path,dpi=150,bbox_inches="tight")plt.close(fig)fig,ax=plt.subplots(figsize=(6,5))ConfusionMatrixDisplay.from_predictions(y_test,predict(model,x_test),display_labels=class_names,cmap="Blues",ax=ax,)fig.tight_layout()fig.savefig(cm_path,dpi=150,bbox_inches="tight")plt.close(fig)return[tree_path,cm_path]defmain()->None:x,y,class_names=load_data()model,x_train,x_test,y_train,y_test=run_train(x,y)metrics=evaluate_model(model,x_test,y_test)plot_paths=save_plots(model,x_test,y_test,class_names)print("Decision Tree Demo")print(f"Train size:{len(x_train)}, Test size:{len(x_test)}")print(f"Accuracy:{metrics['accuracy']:.4f}")print("Confusion Matrix:")print(metrics["confusion_matrix"])print("Classification Report:")print(metrics["classification_report"])print("Saved plots:")forplot_pathinplot_paths:print(plot_path)if__name__=="__main__":main()运行结果
运行python run.py后,终端会输出训练集/测试集大小、准确率、混淆矩阵和分类报告。
图片会保存在当前目录下的figure/文件夹中,通常包括:
decision_tree_structure.pngdecision_tree_confusion_matrix.png
如果分类结果接近满分,这是因为鸢尾花数据集本身比较经典且较容易划分,适合作为决策树入门 demo。