# -*- coding: utf-8 -*-
from .. import to_cpu
from ..core import *
from .dataset import *
from .transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
import tarfile
import pickle
import random
class CIFAR10(Dataset):
    '''CIFAR10 Dataset\n
    Args:
        train (bool): If true, load the training split (data_batch_1 ... data_batch_5, 50000 images); otherwise load the test split (test_batch, 10000 images). Default: True
        transforms: Forwarded to ``Dataset.__init__`` (presumably applied to each image). Default: Compose([ToTensor(), Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
        target_transforms: Forwarded to ``Dataset.__init__`` (presumably applied to each label). Default: None
    Shape:
        - data: [N, 3, 32, 32] (pixel values as stored in the archive)
        - label: [N, 10] (one-hot)
    '''
    # Class names indexed by label id.
    _LABELS = ('airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    # NOTE: the default ``transforms`` object is built once, when the class is
    # defined, and is shared by every instance that does not pass its own.
    def __init__(self, train=True,
                 transforms=Compose([ToTensor(), Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]),
                 target_transforms=None):
        super().__init__(train, transforms, target_transforms)

    def __len__(self):
        # fixed split sizes of CIFAR-10
        if self.train:
            return 50000
        else:
            return 10000

    def state_dict(self):
        '''Return dataset metadata: ``{'label_map': {label id: class name}}``.'''
        return {'label_map': dict(enumerate(CIFAR10._LABELS))}

    def prepare(self):
        '''Download the archive and load the selected split into ``self.data`` / ``self.label``.

        The archive is read in a single streaming pass for all batches of the split.
        '''
        url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
        self._download(url, 'cifar-10.tar.gz')
        if self.train:
            names = [CIFAR10._member_name(i+1, 'train') for i in range(5)]
        else:
            names = [CIFAR10._member_name(None, 'test')]
        batches = CIFAR10._read_batches(self.root+'/cifar-10.tar.gz', names)
        if self.train:
            # NOTE(review): np.empty defaults to float64, so training images are
            # stored as floats while test images keep the pickle's dtype — confirm
            # the transforms treat both splits the same way.
            self.data = np.empty((50000,3*32*32))
            self.label = np.empty((50000,10))
            for i, name in enumerate(names):
                self.data[i*10000:(i+1)*10000] = CIFAR10._batch_data(batches[name])
                self.label[i*10000:(i+1)*10000] = CIFAR10.to_one_hot(CIFAR10._batch_labels(batches[name]), 10)
        else:
            # the test split is one file, so no batch index is involved
            batch = batches[names[0]]
            self.data = CIFAR10._batch_data(batch)
            self.label = CIFAR10.to_one_hot(CIFAR10._batch_labels(batch), 10)
        self.data = self.data.reshape(-1,3,32,32)

    @staticmethod
    def _member_name(idx, data_type):
        '''Base name of the pickled batch inside the archive ('data_batch_<idx>' or 'test_batch').'''
        return 'data_batch_{}'.format(idx) if data_type == 'train' else 'test_batch'

    @staticmethod
    def _read_batches(filename, names):
        '''Unpickle the archive members whose base names are listed in ``names``.

        The gzipped tar is streamed member by member and left as soon as every
        requested batch has been found, instead of being re-opened and fully
        re-decompressed for each batch.

        Args:
            filename (str): path of the downloaded cifar-10 tar.gz archive
            names (iterable of str): member base names, e.g. 'data_batch_1'
        Returns:
            dict: base name -> unpickled batch dict (bytes keys such as b'data', b'labels')
        Raises:
            KeyError: if a requested batch is not in the archive
        '''
        wanted = set(names)
        batches = {}
        with tarfile.open(filename, 'r:gz') as file:
            for item in file:
                base = item.name.rsplit('/', 1)[-1]
                if item.isfile() and base in wanted:
                    # NOTE(security): unpickling can execute arbitrary code and this
                    # archive comes from the network — only use a trusted source.
                    with file.extractfile(item) as f:
                        batches[base] = pickle.load(f, encoding='bytes')
                    if len(batches) == len(wanted):
                        break
        missing = wanted - set(batches)
        if missing:
            raise KeyError('batch(es) {} not found in {}'.format(sorted(missing), filename))
        return batches

    @staticmethod
    def _batch_data(batch):
        '''Pixel array of an unpickled batch, converted with ``np.asarray`` when ``gpu`` is set.'''
        data = batch[b'data']
        # presumably np is the GPU array module when gpu is set — confirm in ..core
        return np.asarray(data) if gpu else data

    @staticmethod
    def _batch_labels(batch):
        '''Integer label ids of an unpickled batch as an ``np`` array.'''
        return np.array(batch[b'labels'])

    def _load_data(self, filename, idx, data_type='train'):
        '''Load the pixel data of one batch (``idx`` is ignored for the test split).'''
        assert data_type in ['train', 'test']
        name = CIFAR10._member_name(idx, data_type)
        return CIFAR10._batch_data(CIFAR10._read_batches(filename, [name])[name])

    def _load_label(self, filename, idx, data_type='train'):
        '''Load the integer label ids of one batch (``idx`` is ignored for the test split).'''
        assert data_type in ['train', 'test']
        name = CIFAR10._member_name(idx, data_type)
        return CIFAR10._batch_labels(CIFAR10._read_batches(filename, [name])[name])

    def show(self, row=10, col=10):
        '''Display a row x col grid of randomly chosen images (sampled with replacement).'''
        H, W = 32, 32
        img = np.zeros((H*row, W*col, 3))
        for r in range(row):
            for c in range(col):
                img[r*H:(r+1)*H, c*W:(c+1)*W] = self.data[random.randint(0, len(self.data)-1)].reshape(3,H,W).transpose(1,2,0)/255
        plt.imshow(to_cpu(img) if gpu else img, interpolation='nearest')
        plt.axis('off')
        plt.show()
class CIFAR100(Dataset):
    '''CIFAR100 Dataset\n
    Args:
        train (bool): If true, load the training split (50000 images); otherwise load the test split (10000 images). Default: True
        transforms: Forwarded to ``Dataset.__init__`` (presumably applied to each image). Default: Compose([ToTensor(), Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
        target_transforms: Forwarded to ``Dataset.__init__`` (presumably applied to each label). Default: None
        label_type (str): "fine" label (the class to which it belongs) or "coarse" label (the superclass to which it belongs). Default: 'fine'
    Shape:
        - data: [N, 3, 32, 32] (pixel values as stored in the archive)
        - label: [N, 100] one-hot for "fine" labels, [N, 20] one-hot for "coarse" labels
    '''
    # Fine (class) names indexed by fine label id.
    _FINE_LABELS = (
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',
        'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
        'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox',
        'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
        'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
        'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
        'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
        'plain', 'plate', 'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket',
        'rose', 'sea', 'seal', 'shark', 'shrew',
        'skunk', 'skyscraper', 'snail', 'snake', 'spider',
        'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe',
        'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    )
    # Coarse (superclass) names indexed by coarse label id.
    _COARSE_LABELS = (
        'aquatic mammals', 'fish', 'flowers', 'food containers',
        'fruit and vegetables', 'household electrical device', 'household furniture',
        'insects', 'large carnivores', 'large man-made outdoor things',
        'large natural outdoor scenes', 'large omnivores and herbivores',
        'medium-sized mammals', 'non-insect invertebrates', 'people',
        'reptiles', 'small mammals', 'trees', 'vehicles 1', 'vehicles 2',
    )

    # NOTE: the default ``transforms`` object is built once, when the class is
    # defined, and is shared by every instance that does not pass its own.
    def __init__(self, train=True,
                 transforms=Compose([ToTensor(), Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]),
                 target_transforms=None,
                 label_type='fine'):
        assert label_type in ['fine', 'coarse']
        # set before super().__init__, which presumably calls prepare() (reads label_type)
        self.label_type = label_type
        super().__init__(train, transforms, target_transforms)

    def __len__(self):
        # fixed split sizes of CIFAR-100
        if self.train:
            return 50000
        else:
            return 10000

    def state_dict(self):
        '''Return dataset metadata: ``{'label_map': {label id: name}}`` for the selected label_type.'''
        if self.label_type == 'fine':
            names = CIFAR100._FINE_LABELS
        elif self.label_type == 'coarse':
            names = CIFAR100._COARSE_LABELS
        else:
            return None
        return {'label_map': dict(enumerate(names))}

    def prepare(self):
        '''Download the archive and load the selected split into ``self.data`` / ``self.label``.

        Images and labels live in the same pickled member, so the archive is read once.
        '''
        url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
        self._download(url, 'cifar-100.tar.gz')
        batch = CIFAR100._read_batch(self.root+'/cifar-100.tar.gz', 'train' if self.train else 'test')
        self.data = CIFAR100._batch_data(batch)
        self.label = CIFAR100.to_one_hot(self._batch_labels(batch), 100 if self.label_type=='fine' else 20)
        self.data = self.data.reshape(-1,3,32,32)

    @staticmethod
    def _read_batch(filename, data_type):
        '''Unpickle the archive member named ``data_type`` ('train' or 'test').

        Args:
            filename (str): path of the downloaded cifar-100 tar.gz archive
            data_type (str): 'train' or 'test'
        Returns:
            dict: the unpickled batch (bytes keys such as b'data', b'fine_labels')
        Raises:
            KeyError: if the member is not in the archive
        '''
        with tarfile.open(filename, 'r:gz') as file:
            # stream members instead of getmembers(), avoiding a second decompression pass
            for item in file:
                if item.isfile() and item.name.rsplit('/', 1)[-1] == data_type:
                    # NOTE(security): unpickling can execute arbitrary code and this
                    # archive comes from the network — only use a trusted source.
                    with file.extractfile(item) as f:
                        return pickle.load(f, encoding='bytes')
        raise KeyError('{!r} not found in {}'.format(data_type, filename))

    @staticmethod
    def _batch_data(batch):
        '''Pixel array of an unpickled batch, converted with ``np.asarray`` when ``gpu`` is set.'''
        data = batch[b'data']
        # presumably np is the GPU array module when gpu is set — confirm in ..core
        return np.asarray(data) if gpu else data

    def _batch_labels(self, batch):
        '''Integer label ids of an unpickled batch for the selected label_type.'''
        if self.label_type == 'fine':
            return np.array(batch[b'fine_labels'])
        elif self.label_type == 'coarse':
            return np.array(batch[b'coarse_labels'])

    def _load_data(self, filename, data_type='train'):
        '''Load the pixel data of the 'train' or 'test' split.'''
        assert data_type in ['train', 'test']
        return CIFAR100._batch_data(CIFAR100._read_batch(filename, data_type))

    def _load_label(self, filename, data_type='train'):
        '''Load the integer label ids of the 'train' or 'test' split.'''
        assert data_type in ['train', 'test']
        return self._batch_labels(CIFAR100._read_batch(filename, data_type))

    def show(self, row=10, col=10):
        '''Display a row x col grid of randomly chosen images (sampled with replacement).'''
        H, W = 32, 32
        img = np.zeros((H*row, W*col, 3))
        for r in range(row):
            for c in range(col):
                img[r*H:(r+1)*H, c*W:(c+1)*W] = self.data[random.randint(0, len(self.data)-1)].reshape(3,H,W).transpose(1,2,0)/255
        plt.imshow(to_cpu(img) if gpu else img, interpolation='nearest')
        plt.axis('off')
        plt.show()