朴素贝叶斯matlab实现

  1. 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
  2. 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
  3. 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。

clc

clear

close all

data=importdata('data.txt');

wholeData=data.data;

%交叉验证选取训练集和测试集

cv=cvpartition(size(wholeData,1),'holdout',0.04);%0.04表明测试数据集占总数据集的比例

cvpartition(n,'holdout',p)创建一个随机分区,用于在n个观测值上进行保持验证。该分区将观察分为训练集和测试(或保持)集。参数p必须是标量,当0

trainData=wholeData(training(cv),:);

testData=wholeData(test(cv),:);

label=data.textdata;

attributeNumber=size(trainData,2);size(A,2):获取矩阵A的列数。attributeValueNumber=5;

%将分类标签转化为数据(因为在分类数据集中有3个类别,分别是R、B、L所以将类别转换为数字)

sampleNumber=size(label,1);

labelData=zeros(sampleNumber,1);

for i=1:sampleNumber(测试集的行数)

if label{i,1}=='R'

labelData(i,1)=1;

elseif label{i,1}=='B'

labelData(i,1)=2;

else

labelData(i,1)=3;

end

end

trainLabel=labelData(training(cv),:);

trainSampleNumber=size(trainLabel,1);

testLabel=labelData(test(cv),:);

%计算每个分类的样本的概率

labelProbability=tabulate(trainLabel);

tabulate函数的功能是创建向量X信息数据频率表。其函数使用格式:

tbl = tabulate(x)

创建的TBL(数据频率表)的结构:第一列:x的唯一值第二列:每个值的实例数量第三列:每个值的百分比

%P_yi,计算P(yi)

P_y1=labelProbability(1,3)/100;(第一行,第三个元素)

P_y2=labelProbability(2,3)/100;

P_y3=labelProbability(3,3)/100;

count_1=zeros(attributeNumber,attributeValueNumber);%count_1(i,j):y=1情况下,第i个属性取j值的数量统计

count_2=zeros(attributeNumber,attributeValueNumber);%count_1(i,j):y=2情况下,第i个属性取j值的数量统计

count_3=zeros(attributeNumber,attributeValueNumber);%count_1(i,j):y=3情况下,第i个属性取j值的数量统计

%统计每一个特征的每个取值的数量

for jj=1:3

for j=1:trainSampleNumber

for ii=1:attributeNumber

for k=1:attributeValueNumber

if jj==1

if trainLabel(j,1)==1&&trainData(j,ii)==k

count_1(ii,k)=count_1(ii,k)+1;

end

elseif jj==2

if trainLabel(j,1)==2&&trainData(j,ii)==k

count_2(ii,k)=count_2(ii,k)+1;

end

else

if trainLabel(j,1)==3&&trainData(j,ii)==k

count_3(ii,k)=count_3(ii,k)+1;

end

end

end

end

end

end

%计算第i个属性取j值的概率,P_a_y1是分类为y=1前提下取值,其他依次类推。P_a_y1=count_1./labelProbability(1,2);

P_a_y2=count_2./labelProbability(2,2);

P_a_y3=count_3./labelProbability(3,2);

%使用测试集进行数据测试

labelPredictNumber=zeros(3,1);

predictLabel=zeros(size(testData,1),1);

for kk=1:size(testData,1)

testDataTemp=testData(kk,:);

Pxy1=1;

Pxy2=1;

Pxy3=1;

%计算P(x|yi)

for iii=1:attributeNumber

Pxy1=Pxy1*P_a_y1(iii,testDataTemp(iii));

Pxy2=Pxy2*P_a_y2(iii,testDataTemp(iii));

Pxy3=Pxy3*P_a_y3(iii,testDataTemp(iii));

end

%计算P(x|yi)*P(yi)

PxyPy1=P_y1*Pxy1;

PxyPy2=P_y2*Pxy2;

PxyPy3=P_y3*Pxy3;

if PxyPy1>PxyPy2&&PxyPy1>PxyPy3

predictLabel(kk,1)=1;

disp(['this item belongs to No.',num2str(1),' label or the R labe l'])

labelPredictNumber(1,1)=labelPredictNumber(1,1)+1;

elseif PxyPy2>PxyPy1&&PxyPy2>PxyPy3

predictLabel(kk,1)=2;

labelPredictNumber(2,1)=labelPredictNumber(2,1)+1;

disp(['this item belongs to No.',num2str(2),' label or the B labe l'])

elseif PxyPy3>PxyPy2&&PxyPy3>PxyPy1

相关文档
最新文档