大家好,欢迎来到IT知识分享网。
读取cifar数据集
在学习深度学习的过程中必须要用到数据集对模型进行训练,本文主要介绍如何读取cifar数据集。
cifar数据集的下载地址为:CIFAR官网(https://www.cs.toronto.edu/~kriz/cifar.html),下载速度较慢
百度云下载地址为:百度云下载地址
下载好后解开压缩包。
1.用python3来读取cifar文件
import pickle
def unpickle(self, f):
    """Load and return the object stored in the pickle file ``f``.

    ``self`` is not used; the signature mirrors a class method, so any value
    (e.g. ``None``) may be passed for it.

    ``encoding='latin1'`` lets Python 3 read pickles written by Python 2 that
    contain NumPy arrays (see the ``pickle.load`` documentation).

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    # The with-block closes the file even if pickle.load raises.
    with open(f, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')
此时读取得到的是一个字典,可以先查看它包含哪些键(key):
# dict has no .key() method; .keys() lists the fields stored in the batch file.
print(d.keys())
然后根据键将有用的信息(图像数据和标签)提取出来:
# Pull the image rows and their class labels out of the unpickled dict.
data, labels = d['data'], d['labels']
读取后的data数据中,每一行是一张图片展开后的3072个数值(前1024个为红色通道,接着依次是绿色和蓝色通道),并不是一张张32×32×3的图片,需要进行转换,如下。
# Each row of `data` holds one image as 3 channel planes of 32x32 values.
# -1 lets numpy infer the image count instead of hard-coding 10000, so the
# same code also works for files with a different number of rows.
new = data.reshape(-1, 3, 32, 32)
# Convert [N][3][32][32] (channels first) to [N][32][32][3] (channels last).
imgs = new.transpose((0, 2, 3, 1))
此时可以作为机器学习的输入数据了。来看看图片是什么样的吧!
import matplotlib.pyplot as plt
# Display one of the converted images. `imgs` is the channels-last array built
# above (the original referenced an undefined name). The plot has no labelled
# artists, so plt.legend() is omitted: it would only emit a warning.
plt.imshow(imgs[1000])
plt.show()
附上一个比较完整的代码(在python3上运行)
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 4 20:21:56 2019
@author: ASUS
"""
import pickle
import numpy as np
import os
class Cifar10DataReader:
    """Serve shuffled mini-batches from the CIFAR-10 python pickle files.

    Training batches come from ``data_batch_1`` .. ``data_batch_5`` in
    ``cifar_folder``, one file at a time. A file is loaded and shuffled, and
    its batches are handed out in order. Then the next file is loaded,
    cycling back to ``data_batch_1`` after ``data_batch_5``. Test samples
    come from ``test_batch``.

    Args:
        cifar_folder: directory containing the unpacked pickle files.
        onehot: if True, labels are returned as length-10 one-hot float
            vectors; otherwise as plain ``int`` class indices.
    """

    def __init__(self, cifar_folder, onehot=True):
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1            # N of the next data_batch_N file to load (1..5)
        self.read_next = True          # True when the next call must load a new file
        self.data_label_train = None   # shuffled (row, label) pairs of the current file
        self.data_label_test = None    # shuffled (row, label) pairs of test_batch, lazy
        self.batch_index = 0           # index of the next batch within the current file

    def unpickle(self, f):
        """Load and return the object stored in the pickle file ``f``.

        ``encoding='latin1'`` lets Python 3 read pickles written by Python 2
        that contain NumPy arrays (see the ``pickle.load`` documentation).
        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        # The with-block closes the file even if pickle.load raises.
        with open(f, 'rb') as fo:
            return pickle.load(fo, encoding='latin1')

    def next_train_data(self, batch_size=100):
        """Return the next ``(images, labels)`` training batch.

        ``images`` is a list of ``batch_size`` arrays of shape (32, 32, 3);
        ``labels`` is a list of one-hot vectors or ints (see ``onehot``).
        Samples left at the end of a file that cannot fill a whole batch
        are skipped.

        Raises:
            AssertionError: if ``batch_size`` does not divide 10000.
        """
        assert 10000 % batch_size == 0, "10000%batch_size!=0"
        if self.read_next:
            self._load_train_file()
        if self.batch_index < len(self.data_label_train) // batch_size:
            start = self.batch_index * batch_size
            datum = self.data_label_train[start:start + batch_size]
            self.batch_index += 1
            return self._decode(datum, self.onehot)
        # Current file exhausted: rewind and serve the batch from the next file.
        self.batch_index = 0
        self.read_next = True
        return self.next_train_data(batch_size=batch_size)

    def _load_train_file(self):
        """Load and shuffle data_batch_<data_index>, then advance data_index cyclically."""
        f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
        dic_train = self.unpickle(f)
        # Pair every image row with its label (labels are 0-9) before shuffling.
        self.data_label_train = list(zip(dic_train['data'], dic_train['labels']))
        np.random.shuffle(self.data_label_train)
        self.read_next = False
        self.data_index = 1 if self.data_index == 5 else self.data_index + 1

    def _decode(self, datum, onehot):
        """Turn ``(flat_row, label)`` pairs into ``(images, labels)`` lists.

        Each flat row is split into 3 planes of 1024 values and rearranged
        into a (32, 32, 3) channels-last array. Labels become length-10
        one-hot float vectors when ``onehot`` is true, otherwise ints.
        """
        rdata = []
        rlabel = []
        for d, l in datum:
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return ``(images, labels)`` for the first ``batch_size`` test samples.

        ``test_batch`` is loaded and shuffled once, on the first call.
        NOTE(review): every call returns the same leading slice of that
        shuffled list. Despite the name, it does not advance; confirm
        whether callers expect iteration.
        """
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            dic_test = self.unpickle(f)
            # Labels are 0-9.
            self.data_label_test = list(zip(dic_test['data'], dic_test['labels']))
            np.random.shuffle(self.data_label_test)
        return self._decode(self.data_label_test[0:batch_size], self.onehot)
if __name__ == "__main__":
    reader = Cifar10DataReader(cifar_folder="./cifar-10-batches-py/")
    import matplotlib.pyplot as plt

    # Peek at one test batch and show its first image.
    images, labels = reader.next_test_data()
    print(np.shape(images), np.shape(labels))
    plt.imshow(images[0])
    plt.show()

    # 600 training batches of 100 samples each; the reader cycles through
    # data_batch_1..5 and wraps around once they are exhausted.
    for _step in range(600):
        images, labels = reader.next_train_data(batch_size=100)
        print(np.shape(images), np.shape(labels))
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/11937.html