使用pytorch实现手写数字识别的案例。
Python(TensorFlow框架)实现手写数字识别系统的方法

Python(TensorFlow框架)实现⼿写数字识别系统的⽅法⼿写数字识别算法的设计与实现本⽂使⽤python基于TensorFlow设计⼿写数字识别算法,并编程实现GUI界⾯,构建⼿写数字识别系统。
这是本⼈的本科毕业论⽂课题,当然,这个也是机器学习的基本问题。
本博⽂不会以论⽂的形式展现,⽽是以编程实战完成机器学习项⽬的⾓度去描述。
项⽬要求:本⽂主要解决的问题是⼿写数字识别,最终要完成⼀个识别系统。
设计识别率⾼的算法,实现快速识别的系统。
1 LeNet-5模型的介绍本⽂实现⼿写数字识别,使⽤的是卷积神经⽹络,建模思想来⾃LeNet-5,如下图所⽰:这是原始的应⽤于⼿写数字识别的⽹络,我认为这也是最简单的深度⽹络。
LeNet-5不包括输⼊,⼀共7层,较低层由卷积层和最⼤池化层交替构成,更⾼层则是全连接和⾼斯连接。
LeNet-5的输⼊与BP神经⽹路的不⼀样。
这⾥假设图像是⿊⽩的,那么LeNet-5的输⼊是⼀个32*32的⼆维矩阵。
同时,输⼊与下⼀层并不是全连接的,⽽是进⾏稀疏连接。
本层每个神经元的输⼊来⾃于前⼀层神经元的局部区域(5×5),卷积核对原始图像卷积的结果加上相应的阈值,得出的结果再经过激活函数处理,输出即形成卷积层(C层)。
卷积层中的每个特征映射都各⾃共享权重和阈值,这样能⼤⼤减少训练开销。
降采样层(S层)为减少数据量同时保存有⽤信息,进⾏亚抽样。
第⼀个卷积层(C1层)由6个特征映射构成,每个特征映射是⼀个28×28的神经元阵列,其中每个神经元负责从5×5的区域通过卷积滤波器提取局部特征。
⼀般情况下,滤波器数量越多,就会得出越多的特征映射,反映越多的原始图像的特征。
本层训练参数共6×(5×5+1)=156个,每个像素点都是由上层5×5=25个像素点和1个阈值连接计算所得,共28×28×156=122304个连接。
Python实现识别手写数字Python图片读入与处理

Python实现识别⼿写数字Python图⽚读⼊与处理写在前⾯在上⼀篇⽂章中,我们已经讲过了我们想要写的全部思路,所以我们不再说全部的思路。
我这⼀次将图⽚的读⼊与处理的代码写了⼀下,和⼤纲写的过程⼀样,这⼀段代码分为以下⼏个部分:读⼊图⽚;将图⽚读取为灰度值矩阵;图⽚背景去噪;切割图⽚,得到⼿写数字的最⼩矩阵;拉伸/压缩图⽚,得到标准⼤⼩为100x100⼤⼩矩阵;将图⽚拉为1x10000⼤⼩向量,存⼊训练矩阵中。
所以下⾯将会对这⼏个函数进⾏详解。
代码分析基础内容⾸先我们现在最前⾯定义基础变量import osfrom skimage import ioimport numpy as np##Essential vavriable 基础变量#Standard size 标准⼤⼩N = 100#Gray threshold 灰度阈值color = 100/255其中标准⼤⼩指的是我们在最后经过切割、拉伸后得到的图⽚的尺⼨为NxN。
灰度阈值指的是在某个点上的灰度超过阈值后则变为1.接下来是这图像处理的⼀部分的主函数filenames = os.listdir(r"./num/")pic = GetTrainPicture(filenames)其中filenames得到在num⽬录下所有⽂件的名称组成的列表。
pic则是通过函数GetTrainPicture得到所有训练图像向量的矩阵。
这⼀篇⽂章主要就是围绕这个函数进⾏讲解。
GetTrainPicture函数GetTrainPicture函数内容如下#Read and save train picture 读取训练图⽚并保存def GetTrainPicture(files):Picture = np.zeros([len(files), N**2+1])#loop all pictures 循环所有图⽚⽂件for i, item in enumerate(files):#Read the picture and turn RGB to grey 读取这个图⽚并转为灰度值img = io.imread('./num/'+item, as_grey = True)#Clear the noise 清除噪⾳img[img>color] = 1#Cut the picture and get the picture of handwritten number#将图⽚进⾏切割,得到有⼿写数字的的图像img = CutPicture(img)#Stretch the picture and get the standard size 100x100#将图⽚进⾏拉伸,得到标准⼤⼩100x100img = StretchPicture(img).reshape(N**2)#Save the picture to the matrix 将图⽚存⼊矩阵Picture[i, 0:N**2] = img#Save picture's name to the matrix 将图⽚的名字存⼊矩阵Picture[i, N**2] = float(item[0])return Picture可以看出这个函数的信息量⾮常⼤,基本上今天做的所有步骤我都把封装到⼀个个函数⾥⾯了,所以这⾥我们可以看到图⽚处理的所有步骤都在这⾥。
python分类算法手写数字识别PPT课件

迭代得到的代价值相减,若结果小于某个阈值则立即停止迭代,此时得到最终解。
1
四、逻辑回归算法
3.多分类问题
前边讨论的都是二分类的问题,即预测结果只有两种类比:0和1,但在许多实际的问题
中,分类结果又多种可能。
这里通常采用的一种处理方式就是one vs all(一对多)的方法,对于有k个类别的数据
,我们可以把问题分割成k个二值分类问题。
式进行变换:
四、逻辑回归算法
1
1.逻辑函数
这样将θTx的取值“挤压”到[0,1]范围内,因此可以将视为分类结果取1的概率。
假设分类结果y的取值只有0和1(即负例和正例),那么在已知x情况下y取1和0的概率
分别是:
将两个式子合并一下就是:
四、逻辑回归算法
1
2.逻辑回归的梯度下降法求解
似然是在确定的结果下去推测产生该结果的可能参数,用来描述已知随机变量输出结果
时,未知参数的可能取值。关于参数θ的似然函数(在数值上)等于给定参数后变量X
的概率:
对上式两边取对数,进行化简:
四、逻辑回归算法
1
2.逻辑回归的梯度下降法求解
目标函数:
当我们令J函数导数为0时,无法求得解析解,所以需要借助迭代的方法去寻求最优解。
首先对J求导:
然后,再应用梯度下降法的迭代公式:
迭代终止的条件是将得到的参数值代入逻辑回归的损失函数中,求出代价值,与上一次
阈值都可以得到一组(FPR,TPR),以FPR作为横坐标,TPR作为纵坐标,就能够画
出ROC图。
1
二、kNN算法
1.kNN算法基础
kNN(k-NearestNeighbor,k最近邻),也称为k邻近算法,就是每个样本都可以用
Python利用逻辑回归模型解决MNIST手写数字识别问题详解

Python利⽤逻辑回归模型解决MNIST⼿写数字识别问题详解本⽂实例讲述了Python利⽤逻辑回归模型解决MNIST⼿写数字识别问题。
分享给⼤家供⼤家参考,具体如下:1、MNIST⼿写识别问题MNIST⼿写数字识别问题:输⼊⿊⽩的⼿写阿拉伯数字,通过机器学习判断输⼊的是⼏。
可以通过TensorFLow下载MNIST⼿写数据集,通过import引⼊MNIST数据集并进⾏读取,会⾃动从⽹上下载所需⽂件。
%matplotlib inlineimport tensorflow as tfimport tensorflow.examples.tutorials.mnist.input_data as input_datamnist=input_data.read_data_sets('MNIST_data/',one_hot=True)import matplotlib.pyplot as pltdef plot_image(image): #图⽚显⽰函数plt.imshow(image.reshape(28,28),cmap='binary')plt.show()print("训练集数量:",mnist.train.num_examples,"特征值组成:",mnist.train.images.shape,"标签组成:",bels.shape)batch_images,batch_labels=mnist.train.next_batch(batch_size=10) #批量读取数据print(batch_images.shape,batch_labels.shape)print('标签值:',np.argmax(bels[1000]),end=' ') #np.argmax()得到实际值print('独热编码表⽰:',bels[1000])plot_image(mnist.train.images[1000]) #显⽰数据集中第1000张图⽚输出训练集的数量有55000个,并打印特征值的shape为(55000,784),其中784代表每张图⽚由28*28个像素点组成,由于是⿊⽩图⽚,每个像素点只有⿊⽩单通道,即通过784个数可以描述⼀张图⽚的特征值。
使用生成对抗实现手写数字识别的GAN-MNIST-Python实现

使用生成对抗实现手写数字识别的GAN-MNIST-Python实现 之前为了熟悉生成对抗,在GitHub上找的一个例子,自己整理了一下。感觉效果还是比较不错的。运行过程中还将生成器的中间结果打印出来比较了一下。 开始的时候的确是什么也不像,瞎胡乱生成 很快就收敛了。可能是我的输入不是用的纯随机序列,而是有输入数字指导的,这样才能够按需生成所要的手写数字【单纯是一个随机序列,我们无法做成一个得到指定数值的手写图像,比如我想要1的图像】 大约训练到35次之后,就比较稳定了 40次训练后,能够很明确的看到很合适的手写数字图像 这些图像都不是mnist数据集中的,而是通过生成器得到的。 最后附有代码,欢迎交流 代码分割线 ------------------------------------------------------------------ 代码分割线 ------------------------------------------------------------------ 代码分割线 ------------------------------------------------------------------ ''' 生成对抗网络(Generative Adversarial Networks,GAN) 最早由 Ian Goodfellow 在 2014 年提出,是目前深度学习领域最具潜力的研究成果之一 它的核心思想是:同时训练两个相互协作、同时又相互竞争的深度神经网络(一个称为生成器 Generator,另一个称为判别器 Discriminator) 来处理无监督学习的相关问题。在训练过程中,两个网络最终都要学习如何处理任务。
本文将以深度卷积生成对抗网络(Deep Convolutional GAN,DCGAN)为例,介绍如何基于 Keras 2.0 框架,以 Tensorflow 为后端, 搭建一个真实可用的 GAN 模型,并以该模型为基础自动生成 MNIST 手写体数字 ''' ''' DCGAN on MNIST using Keras Author: Rowel Atienza
python神经网络编程之手写数字识别

python神经⽹络编程之⼿写数字识别⽬录写在之前⼀、代码框架⼆、准备⼯作三、框架的开始四、训练模型构建五、⼿写数字的识别六、源码七、思考写在之前⾸先是写在之前的⼀些建议:⾸先是关于这本书,我真的认为他是将神经⽹络⾥⾮常棒的⼀本书,但你也需要注意,如果你真的想⾃⼰动⼿去实现,那么你⼀定需要有⼀定的python基础,并且还需要有⼀些python数据科学处理能⼒然后希望⼤家在看这边博客的时候对于神经⽹络已经有⼀些了解了,知道什么是输⼊层,什么是输出层,并且明⽩他们的⼀些理论,在这篇博客中我们仅仅是展开⼀下代码;然后介绍⼀下本篇博客的环境等:语⾔:Python3.8.5环境:jupyter库⽂件: numpy | matplotlib | scipy⼀、代码框架我们即将设计⼀个神经⽹络对象,它可以帮我们去做数据的训练,以及数据的预测,所以我们将具有以下的三个⽅法:⾸先我们需要初始化这个函数,我们希望这个神经⽹络仅有三层,因为再多也不过是在隐藏层去做⽂章,所以先做⼀个简单的。
那么我们需要知道我们输⼊层、隐藏层和输出层的节点个数;训练函数,我们需要去做训练,得到我们需要的权重。
通过我们已有的权重,将给定的输⼊去做输出。
⼆、准备⼯作现在我们需要准备⼀下:1.将我们需要的库导⼊import numpy as npimport scipy.special as speimport matplotlib.pyplot as plt2.构建⼀个类class neuralnetwork:# 我们需要去初始化⼀个神经⽹络def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):passdef train(self, inputs_list, targets_list):passdef query(self, inputs_list):pass3.我们的主函数input_nodes = 784 # 输⼊层的节点数hidden_nodes = 88 # 隐藏层的节点数output_nodes = 10 # 输出层的节点数learn_rate = 0.05 # 学习率n = neuralnetwork(input_nodes, hidden_nodes, output_nodes, learn_rate)4.导⼊⽂件data_file = open("E:\sklearn_data\神经⽹络数字识别\mnist_train.csv", 'r')data_list = data_file.readlines()data_file.close()file2 = open("E:\sklearn_data\神经⽹络数字识别\mnist_test.csv")answer_data = file2.readlines()file2.close()这⾥需要介绍以下这个数据集,训练集在这⾥,测试集在这⾥三、框架的开始def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):self.inodes = inputnodes # 输⼊层节点设定self.hnodes = hiddennodes # 影藏层节点设定self.onodes = outputnodes # 输出层节点设定self.lr = learningrate # 学习率设定,这⾥可以改进的self.wih = (np.random.normal(0.0, pow(self.hnodes, -0.5),(self.hnodes, self.inodes))) # 这⾥是输⼊层与隐藏层之间的连接self.who = (np.random.normal(0.0, pow(self.onodes, -0.5),(self.onodes, self.hnodes))) # 这⾥是隐藏层与输出层之间的连接self.activation_function = lambda x: spe.expit(x) # 返回sigmoid函数Δw j,k =α∗E k ∗ sigmoid (O k )∗(1−sigmoid(O k ))⋅O j ⊤def query(self, inputs_list):inputs = np.array(inputs_list, ndmin=2).T # 输⼊进来的⼆维图像数据hidden_inputs = np.dot(self.wih, inputs) # 隐藏层计算,说⽩了就是线性代数中的矩阵的点积hidden_outputs = self.activation_function(hidden_inputs) # 将隐藏层的输出是经过sigmoid函数处理final_inputs = np.dot(self.who, hidden_outputs) # 原理同hidden_inputsfinal_outputs = self.activation_function(final_inputs) # 原理同hidden_outputsreturn final_outputs # 最终的输出结果就是我们预测的数据这⾥我们对预测这⼀部分做⼀个简单的解释:我们之前的定义输出的节点是10个,对应的是⼗个数字。
python实现手写数字识别(小白入门)
python实现⼿写数字识别(⼩⽩⼊门)⼿写数字识别(⼩⽩⼊门)今早刚刚上了节实验课,关于逻辑回归,所以⼿有点刺挠就想发个博客,作为刚刚⼊门的⼩⽩,看到代码运⾏成功就有点⼩激动,这个实验没啥含⾦量,所以路过的⼤⽜不要停留,我怕你们吐槽哈哈。
实验结果:1.数据预处理其实呢,原理很简单,就是使⽤多变量逻辑回归,将训练28*28图⽚的灰度值转换成⼀维矩阵,这就变成了求784个特征向量1个标签的逻辑回归问题。
代码如下:#数据预处理trainData = np.loadtxt(open('digits_training.csv','r'), delimiter=",",skiprows=1)#装载数据MTrain, NTrain = np.shape(trainData)#⾏列数print("训练集:",MTrain,NTrain)xTrain = trainData[:,1:NTrain]xTrain_col_avg = np.mean(xTrain, axis=0)#对各列求均值xTrain =(xTrain- xTrain_col_avg)/255#归⼀化yTrain = trainData[:,0]2.训练模型对于数学差的⼀批的我来说,学习算法真的是太太太扎⼼了,好在具体算法封装在了sklearn库中。
简单两⾏代码即可完成。
具体参数的含义随随便便⼀搜到处都是,我就不班门弄斧了,每次看见算法除了头晕啥感觉没有。
model = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=500)model.fit(xTrain, yTrain)3.测试模型,保存接下来测试⼀下模型,准确率能达到百分之90,也不算太⾼,训练数据集本来也不是很多。
python实现基于SVM手写数字识别功能
python实现基于SVM⼿写数字识别功能本⽂实例为⼤家分享了SVM⼿写数字识别功能的具体代码,供⼤家参考,具体内容如下1、SVM⼿写数字识别识别步骤:(1)样本图像的准备。
(2)图像尺⼨标准化:将图像⼤⼩都标准化为8*8⼤⼩。
(3)读取未知样本图像,提取图像特征,⽣成图像特征组。
(4)将未知测试样本图像特征组送⼊SVM进⾏测试,将测试的结果输出。
识别代码:#!/usr/bin/env pythonimport numpy as npimport mlpyimport cv2print 'loading ...'def getnumc(fn):'''返回数字特征'''fnimg = cv2.imread(fn) #读取图像img=cv2.resize(fnimg,(8,8)) #将图像⼤⼩调整为8*8alltz=[]for now_h in xrange(0,8):xtz=[]for now_w in xrange(0,8):b = img[now_h,now_w,0]g = img[now_h,now_w,1]r = img[now_h,now_w,2]btz=255-bgtz=255-grtz=255-rif btz>0 or gtz>0 or rtz>0:nowtz=1else:nowtz=0xtz.append(nowtz)alltz+=xtzreturn alltz#读取样本数字x=[]y=[]for numi in xrange(1,10):for numij in xrange(1,5):fn='nums/'+str(numi)+'-'+str(numij)+'.png'x.append(getnumc(fn))y.append(numi)x=np.array(x)y=np.array(y)svm = mlpy.LibSvm(svm_type='c_svc', kernel_type='poly',gamma=10)svm.learn(x, y)print u"训练样本测试:"print svm.pred(x)print u"未知图像测试:"for iii in xrange (1,10):testfn= 'nums/test/'+str(iii)+'-test.png'testx=[]testx.append(getnumc(testfn))printprint testfn+":",print svm.pred(testx)样本:结果:以上就是本⽂的全部内容,希望对⼤家的学习有所帮助,也希望⼤家多多⽀持。
pytorch简单的分类模型案例
pytorch简单的分类模型案例以下是一个使用PyTorch构建的简单分类模型的示例。
在这个例子中,我们将使用MNIST数据集,这是一个手写数字的数据集。
```python导入必要的库import torchimport as nnimport as optimfrom torchvision import datasets, transforms定义超参数input_size = 784 输入图像的大小 (2828)hidden_size = 100 隐藏层的大小num_classes = 10 输出类别的数量 (0-9)num_epochs = 5 训练周期的次数batch_size = 100 批处理的大小learning_rate = 学习率数据预处理:归一化transform = ()加载数据集train_dataset = (root='./data', train=True, transform=transform, download=True)test_dataset = (root='./data', train=False, transform=transform)train_loader = (dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = (dataset=test_dataset, batch_size=batch_size, shuffle=False)定义模型结构class Net():def __init__(self, input_size, hidden_size, num_classes):super(Net, self).__init__()= (input_size, hidden_size)= ()= (hidden_size, num_classes)def forward(self, x):out = (x)out = (out)out = (out)return outmodel = Net(input_size, hidden_size, num_classes) 实例化模型对象定义损失函数和优化器criterion = ()optimizer = ((), lr=learning_rate) 使用随机梯度下降优化器训练模型for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):outputs = model(images)loss = criterion(outputs, labels)_grad() 清空之前的梯度() 反向传播,计算梯度() 更新权重参数print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, ())) 打印每个epoch的损失值```。
python实现识别手写数字python图像识别算法
python实现识别⼿写数字python图像识别算法写在前⾯这⼀段的内容可以说是最难的⼀部分之⼀了,因为是识别图像,所以涉及到的算法会相⽐之前的来说⽐较困难,所以我尽量会讲得清楚⼀点。
⽽且因为在编写的过程中,把前⾯的⼀些逻辑也修改了⼀些,将其变得更完善了,所以⼀切以本篇的为准。
当然,如果想要直接看代码,代码全部放在我的GitHub中,所以这篇⽂章主要负责讲解,如需代码请⾃⾏前往GitHub。
本次⼤纲上⼀次写到了数据库的建⽴,我们能够实时的将更新的训练图⽚存⼊CSV⽂件中。
所以这次继续往下⾛,该轮到识别图⽚的内容了。
⾸先我们需要从⽂件夹中提取出需要被识别的图⽚test.png,并且把它经过与训练图⽚相同的处理得到1x10000⼤⼩的向量。
因为两者之间存在微⼩的差异,我也不是很想再往源代码之中增加逻辑了,所以我就直接把增加待识别图⽚的函数重新写⼀个命名为GetTestPicture,内容与GetTrainPicture类似,只不过少了“增加图⽚名称”这⼀个部分。
之后我们就可以开始进⾏正式图⽚识别内容了。
主要是计算待识别图⽚与所有训练图⽚的距离。
当两个图⽚距离越近的时候,说明他们越相似,那么他们很有可能写的就是同⼀个数。
所以利⽤这个原理,我们可以找出距离待识别图像最近的⼏个训练图⽚,并输出他们的数字分别是⼏。
⽐如说我想输出前三个,前三个分别是3,3,9,那就说明这个待识别图⽚很有可能是3.之后还可以对每⼀个位置加个权重,具体的就放在下⼀次再讲,本节内容已经够多了。
(第⼀篇⽂章之中我说过利⽤图⽚洞数检测。
我尝试了⼀下,认为有些不妥,具体原因放在本⽂末。
)MAIN代码所以直接把主要代码放上来,逻辑相对来说还是⽐较清晰的import osimport OperatePicture as OPimport OperateDatabase as ODimport PictureAlgorithm as PAimport csv##Essential vavriable 基础变量#Standard size 标准⼤⼩N = 100#Gray threshold 灰度阈值color = 200/255n = 10#读取原CSV⽂件reader = list(csv.reader(open('Database.csv', encoding = 'utf-8')))#清除读取后的第⼀个空⾏del reader[0]#读取num⽬录下的所有⽂件名fileNames = os.listdir(r"./num/")#对⽐fileNames与reader,得到新增的图⽚newFileNamesnewFileNames = OD.NewFiles(fileNames, reader)print('New pictures are: ', newFileNames)#得到newFilesNames对应的矩阵pic = OP.GetTrainPicture(newFileNames)#将新增图⽚矩阵存⼊CSV中OD.SaveToCSV(pic, newFileNames)#将原数据库矩阵与新数据库矩阵合并pic = bination(reader, pic)#得到待识别图⽚testFiles = os.listdir(r"./test/")testPic = OP.GetTestPicture(testFiles)#计算每⼀个待识别图⽚的可能分类result = PA.CalculateResult(testPic, pic)for item in result:for i in range(n):print('第'+str(i+1)+'个向量为'+str(item[i+n])+',距离为'+str(item[i]))相⽐上⼀篇⽂章的内容,本篇⽂章⾥只增加了下⾯的的⼀段代码,即得到待识别图⽚名称、得到待识别图⽚向量、计算分类。
- 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
- 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
- 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。
使用pytorch实现手写数字识别的案例。
好的,下面是使用PyTorch实现手写数字识别的简单案例。
首先,导入必要的库:
```python
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
```
定义超参数:
```python
batch_size = 64
num_epochs = 10
learning_rate = 0.001
```
下载MNIST数据集并进行预处理:
```python
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
```
定义CNN模型:
```python
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2)
self.relu1 = nn.ReLU()
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
self.relu2 = nn.ReLU()
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) self.fc = nn.Linear(7*7*32, 10)
def forward(self, x):
out = self.conv1(x)
out = self.relu1(out)
out = self.maxpool1(out)
out = self.conv2(out)
out = self.relu2(out)
out = self.maxpool2(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
model = CNN()
```
定义损失函数和优化器:
```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate) ```
训练模型:
```python
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
```
测试模型:
```python
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
```
这样就完成了一个简单的手写数字识别的PyTorch实现。