读取cifar数据集

读取cifar数据集读取cifar数据集在学习深度学习的过程中必须要用到数据集对模型进行训练,本文主要介绍如何读取cifar数据集。cifar数据集的下载地址为:CIFAR官网,下载速度较慢百度云下载地址为:百度云下载地址下载好后解开压缩包。1.用python3来读取cifar文件importpickledefunpickle(self,f):fo=open(f,…

大家好,欢迎来到IT知识分享网。读取cifar数据集"

读取cifar数据集

在学习深度学习的过程中必须要用到数据集对模型进行训练,本文主要介绍如何读取cifar数据集。
cifar数据集的下载地址为:CIFAR官网,下载速度较慢
百度云下载地址为:百度云下载地址
下载好后解开压缩包。
1.用python3来读取cifar文件

   import pickle
    def unpickle(self,f):
        fo = open(f, 'rb')
        d = pickle.load(fo,encoding='latin1')
        fo.close()
        return d

此时读取的文件是一个字典,你可以查看关键字

print(d.key())

然后,根据关键字将有用的信息提取出来

data = d['data']
labels = d['labels']

读取后的data数据并不是一张张32323的图片,需要进行转换,如下。

new = data.reshape(10000,3,32,32)
#将[10000][3][32][32]转为[10000][32][32][3]
imgs = new.transpose((0,2,3,1))

此时可以作为机器学习的输入数据了。来看看图片是什么样的吧!

import matplotlib.pyplot as plt
plt.imshow(pic_test[1000])
plt.legend()
plt.show()

附上一个比较完整的代码(在python3上运行)

# -*- coding: utf-8 -*-
"""
Created on Wed Sep  4 20:21:56 2019

@author: ASUS
"""

import pickle
import numpy as np
import os
 
class Cifar10DataReader():
    def __init__(self,cifar_folder,onehot=True):
        self.cifar_folder=cifar_folder
        self.onehot=onehot
        self.data_index=1
        self.read_next=True
        self.data_label_train=None
        self.data_label_test=None
        self.batch_index=0
        
    def unpickle(self,f):
        fo = open(f, 'rb')
        d = pickle.load(fo,encoding='latin1')
        fo.close()
        return d
    
    def next_train_data(self,batch_size=100):
        assert 10000%batch_size==0,"10000%batch_size!=0"
        rdata=None
        rlabel=None
        if self.read_next:
            f=os.path.join(self.cifar_folder,"data_batch_%s"%(self.data_index))
            #print 'read: %s'%f
            dic_train=self.unpickle(f)
            self.data_label_train=list(zip(dic_train['data'],dic_train['labels']))#label 0~9
            np.random.shuffle(self.data_label_train)
            
            self.read_next=False
            if self.data_index==5:
                self.data_index=1
            else: 
                self.data_index+=1
        
        if self.batch_index<len(self.data_label_train)//batch_size:
            #print self.batch_index
            datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]
            self.batch_index+=1
            rdata,rlabel=self._decode(datum,self.onehot)
        else:
            self.batch_index=0
            self.read_next=True
            return self.next_train_data(batch_size=batch_size)
            
        return rdata,rlabel
    
    def _decode(self,datum,onehot):
        rdata=list();rlabel=list()
        if onehot:
            for d,l in datum:
                rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))
                hot=np.zeros(10)
                hot[int(l)]=1
                rlabel.append(hot)
        else:
            for d,l in datum:
                rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))
                rlabel.append(int(l))
        return rdata,rlabel
            
    def next_test_data(self,batch_size=100):
        if self.data_label_test is None:
            f=os.path.join(self.cifar_folder,"test_batch")
            #print 'read: %s'%f
            dic_test=self.unpickle(f)
            data=dic_test['data']
            labels=dic_test['labels']#0~9
            self.data_label_test=list(zip(data,labels))
        
        np.random.shuffle(self.data_label_test)
        datum=self.data_label_test[0:batch_size]
        
        return self._decode(datum,self.onehot)
 
if __name__=="__main__":
    dr=Cifar10DataReader(cifar_folder="./cifar-10-batches-py/")
    import matplotlib.pyplot as plt
    d,l=dr.next_test_data()
    print (np.shape(d),np.shape(l))
    plt.imshow(d[0])
    plt.show()
    for i in range(600):
        d,l=dr.next_train_data(batch_size=100)
        print (np.shape(d),np.shape(l))
 
 

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/11937.html

(0)

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信