鸢尾花分类-Tensorflow2入门
TensorFlow实现神经网络模型的一般流程
准备数据
定义模型
训练模型
评估模型
使用模型
保存模型
对我这样的新手来说,最困难的部分实际上是准备数据的过程。
利用Tensorflow2来实现简单的鸢尾花分类任务
一、准备数据
1. 导入数据
鸢尾花数据集是sklearn中的自带的数据集,只需要导入sklearn的datasets
1 |
|
datasets.load_iris()是datasets模块中的一个函数,用于加载鸢尾花数据集。它返回一个包含特征数据和目标数据的对象,通过调用返回的对象的data
属性,可以获取鸢尾花数据集的特征数据。这些特征数据通常是一个二维数组,每一行表示一个样本,每一列表示一个特征。通过调用返回的对象的target
属性,可以获取鸢尾花数据集的目标数据。目标数据通常是一个一维数组,表示每个样本所属的类别或标签。
2. 打乱数据集
官方给出的数据集是有顺序的,所以需要打乱数据集顺序,避免陷入局部最优解,可以获得更高的分类准确率。
1 |
|
其中np.random.seed(1)
用于设置随机数种子,目的是保证数据和标签被随机打乱顺序后依然是对应的
3.拆分训练数据和测试数据
接下来拆分数据集和训练集,鸢尾花数据集总共150组数据,指定前120组数据为训练集,剩余的30组数据为测试集,通常训练集大小为总数据集大小的60%到70%
1 |
|
列表截取的完整语法是[start:stop:step]
,其中:
start
:表示切片的起始位置(包含该位置的元素),默认为0,即从序列的开头开始stop
:表示切片的结束位置(不包含该位置的元素),默认为序列的长度step
:表示切片的步长(即每次取元素的间隔),默认为1
4.强制类型转换
tensorflow中一般默认数据类型为float32,但是读入的数据可能不是,所以做一下强制类型转换,保证数据的类型统一
1 |
|
5.创建TensorFlow需要的数据集
创建TensorFlow数据集对象,并且设置了每个批次的大小为32,以便在模型训练和测试过程中进行批量处理
1 |
|
二、搭建神经网络
1.创建变量
1 |
|
tf.random.truncated_normal()
函数用于生成一个指定形状的张量,其中的值是从截断正态分布中随机采样得到的,stddev=0.1
表示生成的随机值的标准差为 0.1,[4, 3]
表示张量的形状,表示一个 4x3 的矩阵,这段代码使用TensorFlow创建了两个变量 w1
和 b1
,并初始化它们的值。通过以上代码,创建了两个变量 w1
和 b1
,并将它们初始化为从截断正态分布中随机生成的值。这些变量可以在神经网络模型中用作权重和偏置,用于进行正向传播和反向传播的计算。
2.设置学习参数
1 |
|
3.训练神经网络
1 |
|