news 2026/6/11 20:13:59

【0基础学机器学习】2.决策树

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【0基础学机器学习】2.决策树

决策树模型笔记

1. 基础知识

基本模型形式

决策树是一种常见的监督学习模型,既可以做分类,也可以做回归。它通过一系列“如果…那么…”的规则不断划分特征空间,最终在叶子节点给出预测结果。

对于分类任务,模型会根据样本特征逐层判断,例如:

  • 如果花瓣长度小于某个阈值,进入左子树
  • 否则进入右子树

最终到达某个叶子节点后,叶子节点中占比最高的类别就是预测类别。

核心目标

决策树的核心目标是:在每一次节点划分时,找到一个最优特征和最优切分点,让划分后的子节点尽可能“纯”。

分类任务中常见目标包括:

  • 让同一类别样本尽量落到同一个叶子节点
  • 降低节点的不确定性
  • 提升整体分类准确率

损失函数

决策树通常不直接写成统一的全局损失函数最小化问题,而是在每个节点上贪心地选择最优划分标准。

常见划分指标有:

  • 基尼指数(Gini Index)
  • 信息熵(Entropy)

以基尼指数为例:

Gini(D) = 1 - Σ(p_k)^2

其中p_k表示样本集合D中第k类样本所占比例。基尼指数越小,说明节点越纯。

参数求解

决策树的参数求解过程本质上是一个递归划分过程:

  1. 在当前节点中遍历候选特征
  2. 为每个特征尝试不同划分阈值
  3. 计算划分后的不纯度下降
  4. 选择收益最大的划分方式
  5. 递归生成左右子树,直到满足停止条件

常见停止条件包括:

  • 达到最大树深度
  • 节点样本数过少
  • 节点已经足够纯

应用示例(Python实现)

本项目使用scikit-learn中的DecisionTreeClassifier实现一个经典的鸢尾花三分类任务:

fromsklearn.treeimportDecisionTreeClassifier model=DecisionTreeClassifier(max_depth=3,random_state=42)model.fit(x_train,y_train)y_pred=model.predict(x_test)

注意要点

  • 决策树容易过拟合,需要通过max_depthmin_samples_split等参数控制复杂度
  • 决策树对特征缩放不敏感,一般不强制要求标准化
  • 树结构可解释性强,适合教学演示和规则分析
  • 单棵树性能通常不如集成模型,但更容易理解

2. 代码实践

model.py

model.py负责定义决策树模型、训练模型和预测接口。这里统一封装了:

  • build_model():创建模型
  • train_model():拟合训练数据
  • predict():执行预测
fromsklearn.treeimportDecisionTreeClassifierdefbuild_model(criterion:str="gini",max_depth:int=3,random_state:int=42,)->DecisionTreeClassifier:"""创建决策树分类模型。"""returnDecisionTreeClassifier(criterion=criterion,max_depth=max_depth,random_state=random_state,)deftrain_model(x_train,y_train,criterion:str="gini",max_depth:int=3,random_state:int=42,)->DecisionTreeClassifier:"""训练决策树分类模型。"""model=build_model(criterion=criterion,max_depth=max_depth,random_state=random_state,)model.fit(x_train,y_train)returnmodeldefpredict(model:DecisionTreeClassifier,x_test):"""使用训练好的模型进行预测。"""returnmodel.predict(x_test)

train.py

train.py负责训练流程,包括:

  • 训练集和测试集划分
  • 调用train_model()完成训练

代码中使用了stratify=y,保证分类任务中训练集和测试集的类别分布更加稳定。

fromsklearn.model_selectionimporttrain_test_splitfrommodelimporttrain_modeldefsplit_data(x,y,test_size:float=0.2,random_state:int=42,):"""划分训练集和测试集。"""returntrain_test_split(x,y,test_size=test_size,random_state=random_state,stratify=y,)defrun_train(x,y,test_size:float=0.2,random_state:int=42,criterion:str="gini",max_depth:int=3,):"""完成数据划分和模型训练。"""x_train,x_test,y_train,y_test=split_data(x,y,test_size=test_size,random_state=random_state,)model=train_model(x_train,y_train,criterion=criterion,max_depth=max_depth,random_state=random_state,)returnmodel,x_train,x_test,y_train,y_test

eval.py

eval.py负责评估模型效果,输出:

  • 准确率accuracy
  • 混淆矩阵confusion_matrix
  • 分类报告classification_report

这些指标能帮助我们同时观察总体表现和各类别的精确率、召回率、F1 值。

fromsklearn.metricsimportaccuracy_score,classification_report,confusion_matrixfrommodelimportpredictdefevaluate_model(model,x_test,y_test)->dict:"""评估决策树分类模型效果。"""y_pred=predict(model,x_test)return{"accuracy":accuracy_score(y_test,y_pred),"confusion_matrix":confusion_matrix(y_test,y_pred),"classification_report":classification_report(y_test,y_pred),}

dataload.py

dataload.pysklearn.datasets中加载鸢尾花数据集:

  • 特征x:4 个花萼/花瓣数值特征
  • 标签y:3 个类别标签
  • target_names:类别名称,用于可视化展示
importpandasaspdfromsklearn.datasetsimportload_irisdefload_data():"""加载 sklearn 自带的 iris 分类数据集。"""dataset=load_iris()x=pd.DataFrame(dataset.data,columns=dataset.feature_names)y=pd.Series(dataset.target,name="target")returnx,y,dataset.target_names

run.py

run.py是项目入口,负责串联整个流程:

  1. 加载数据
  2. 训练模型
  3. 评估模型
  4. 保存可视化结果

可视化部分包含:

  • 决策树结构图
  • 混淆矩阵图
frompathlibimportPathimportmatplotlib matplotlib.use("Agg")importmatplotlib.pyplotaspltfromsklearn.metricsimportConfusionMatrixDisplayfromsklearn.treeimportplot_treefromdataloadimportload_datafromevalimportevaluate_modelfrommodelimportpredictfromtrainimportrun_traindefsave_plots(model,x_test,y_test,class_names)->list[Path]:"""保存决策树结构图和混淆矩阵图。"""current_dir=Path(__file__).resolve().parent output_dir=current_dir/"figure"output_dir.mkdir(exist_ok=True)tree_path=output_dir/"decision_tree_structure.png"cm_path=output_dir/"decision_tree_confusion_matrix.png"fig,ax=plt.subplots(figsize=(16,10))plot_tree(model,feature_names=list(x_test.columns),class_names=list(class_names),filled=True,rounded=True,ax=ax,)fig.tight_layout()fig.savefig(tree_path,dpi=150,bbox_inches="tight")plt.close(fig)fig,ax=plt.subplots(figsize=(6,5))ConfusionMatrixDisplay.from_predictions(y_test,predict(model,x_test),display_labels=class_names,cmap="Blues",ax=ax,)fig.tight_layout()fig.savefig(cm_path,dpi=150,bbox_inches="tight")plt.close(fig)return[tree_path,cm_path]defmain()->None:x,y,class_names=load_data()model,x_train,x_test,y_train,y_test=run_train(x,y)metrics=evaluate_model(model,x_test,y_test)plot_paths=save_plots(model,x_test,y_test,class_names)print("Decision Tree Demo")print(f"Train size:{len(x_train)}, Test size:{len(x_test)}")print(f"Accuracy:{metrics['accuracy']:.4f}")print("Confusion Matrix:")print(metrics["confusion_matrix"])print("Classification Report:")print(metrics["classification_report"])print("Saved plots:")forplot_pathinplot_paths:print(plot_path)if__name__=="__main__":main()

运行结果

运行python run.py后,终端会输出训练集/测试集大小、准确率、混淆矩阵和分类报告。

图片会保存在当前目录下的figure/文件夹中,通常包括:

  • decision_tree_structure.png
  • decision_tree_confusion_matrix.png

如果分类结果接近满分,这是因为鸢尾花数据集本身比较经典且较容易划分,适合作为决策树入门 demo。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/11 20:08:12

SEO_掌握这些SEO核心技巧让你的排名稳步上升

SEO核心技巧:让你的排名稳步上升在互联网时代,拥有一个高排名的网站是吸引流量和提升业务的关键。如果你希望你的网站在百度搜索结果中获得更好的位置,那么掌握一些SEO核心技巧是必不可少的。本文将详细讲解几个关键的SEO技巧,帮助…

作者头像 李华
网站建设 2026/5/18 22:48:03

RemoteDebug:ESP32/ESP8266 WiFi远程调试库深度解析

1. RemoteDebug 库深度解析:面向 ESP32/ESP8266 的嵌入式 WiFi 远程调试系统RemoteDebug 是一款专为 ESP32 和 ESP8266 平台设计的轻量级、高性能远程调试库。它并非简单的Serial.print()替代品,而是一套完整的、工程化程度极高的调试基础设施&#xff0…

作者头像 李华
网站建设 2026/5/18 22:48:00

Arduino I2C LCD驱动库:PCF8574与HD44780通信详解

1. 项目概述LCD_I2C 是一款专为 Arduino 平台设计的轻量级 C 库,用于驱动基于 PCF8574 IC 扩展芯片的 162 字符型液晶显示屏。该库不依赖于 Arduino LiquidCrystal 库的底层并行接口实现,而是完全重构为面向 IC 总线通信的专用驱动架构,通过 …

作者头像 李华
网站建设 2026/5/18 22:48:03

Janus-Pro-7B开源生态与社区贡献指南

Janus-Pro-7B开源生态与社区贡献指南 如果你对Janus-Pro-7B这个模型感兴趣,并且想为它做点什么,那这篇文章就是为你准备的。开源项目就像一个热闹的集市,模型本身是集市中央最亮眼的商品,但围绕它搭建的货架、提供的工具、以及来…

作者头像 李华