手写数字识别(mxnet官网例子)
- 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
- 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
- 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。
手写数字识别
简介:通过MNIST数据集建立一个手写数字分类器。(MNIST对于手写数据分类任务是一个广泛使用的数据集)。
1.前提:mxnet 0.10及以上、python、jupyter notebook(有时间可以jupyter notebook的用法,如:PPT的制作)
pip install requests jupyter ——python下jupyter notebook 的安装
2.加载数据集:
import mxnet as mx
mnist = mx.test_utils.get_mnist()
此时MXNET数据集已完全加载到内存中(注:此法对于大型数据集不适用)
考虑要素:快速高效地从源直接流数据+输入样本的顺序
图像通常用4维数组来表示:(batch_size,num_channels,width,height)
对于MNIST数据集,因为是28*28灰度图像,所以只有1个颜色通道,width=28,height=28,本例中batch=100(批处理100),即输入形状是(batch_size,1,28,28)
数据迭代器通过随机的调整输入来解决连续feed相同样本的问题。测试数据的顺序无关紧要。
batch_size = 100
train_iter=,mnist['train_label'], batch_size, shuffle=True)
val_iter = , mnist['test_label'], batch_size)
——初始化MNIST数据集的数据迭代器(2个:训练数据+测试数据)
3.训练+预测:(2种方法)(CNN优于MLP)
1)传统深度神经网络结构——MLP(多层神经网络)
MLP——MXNET的符号接口
为输入的数据创建一个占位符变量
data =
data =
——将数据从4维变成2维(batch_size,num_channel*width*height) fc1 = , num_hidden=128)
act1 = , act_type="relu")
——第一个全连接层及相应的激活函数
fc2 = , num_hidden = 64)
act2 = , act_type="relu")
——第二个全连接层及相应的激活函数
(声明2个全连接层,每层有128个和64个神经元)
fc3 = , num_hidden=10)
——声明大小10的最终完全连接层
mlp = , name='softmax')
——softmax的交叉熵损失
MNIST的MLP网络结构
以上,已完成了数据迭代器和神经网络的申明,下面可以进行训练。超参数:处理大小、学习速率
import logging
logging.getLogger().setLevel(logging.DEBUG) ——记录到标准输出
mlp_model = , context=mx.cpu())
——在CPU上创建一个可训练的模块
mlp_model.fit(train_iter ——训练数据
eval_data=val_iter, ——验证数据
optimizer='sgd', ——使用SGD训练
optimizer_params={'learning_rate':0.1}, ——使用
固定的学习速率
eval_metric='acc', ——训练过程中报告准确性
batch_end_callback=, 100), ——每批次100数据输
出的进展num_epoch=10) ——训练
至多通过10个数据
预测:
test_iter = , None, batch_size)
prob = mlp_model.predict(test_iter)
assert prob.shape == (10000, 10)
——计算每一个测试图像可能的预测得分(prob[i][j]第i个测试图像包含j输出类)
test_iter = , mnist['test_label'], batch_size) ——预测精度的方法
acc =
mlp_model.score(test_iter, acc)
print(acc)
assert acc.get()[1] > 0.96
如果一切顺利的话,我们将看到一个准确的值大约是0.96,这意味着我们能够准确地预测96%的测试图像中的数字。
2)卷积神经网络(CNN)
卷积层+池化层
data =
conv1 = , kernel=(5,5), num_filter=20)
tanh1 = , act_type="tanh")
pool1 = , pool_type="max", kernel=(2,2), stride=(2,2))
——第一个卷积层、池化层
conv2 = , kernel=(5,5), num_filter=50) ——第二个卷积层
tanh2 = , act_type="tanh")
pool2 = , pool_type="max", kernel=(2,2), stride=(2,2))
flatten = ——第一个全连接层
fc1 = , num_hidden=500)
tanh3 = , act_type="tanh")
fc2 = , num_hidden=10)——第二个全连接层
lenet = , name='softmax') ——Softmax损失
LeNet第一个卷积层+池化层