Source code for qualia2.data.basic

# -*- coding: utf-8 -*- 
from .. import to_cpu
from ..core import *
from .dataset import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

[docs]class Spiral(Dataset): '''Spiral Dataset\n Args: num_class (int): number of classes num_data (int): number of data for each classes Shape: - data: [num_class*num_data, 2] - label: [num_class*num_data, num_class] ''' def __init__(self, num_class=3, num_data=100): super().__init__() self.num_class = num_class self.num_data = num_data self.data = np.zeros((num_data*num_class, 2)) self.label = np.zeros((num_data*num_class, num_class)) for c in range(num_class): for i in range(num_data): rate = i / num_data radius = 1.0*rate theta = c*4.0 + 4.0*rate + np.random.randn()*0.2 self.data[num_data*c+i,0] = radius*np.sin(theta) self.data[num_data*c+i,1] = radius*np.cos(theta) self.label[num_data*c+i,c] = 1
[docs] def show(self, label=None): fig, ax = plt.subplots() for c in range(self.num_class): if gpu: ax.scatter(to_cpu(self.data[(self.label[:,c]>0)][:,0]),to_cpu(self.data[(self.label[:,c]>0)][:,1])) else: ax.scatter(self.data[(self.label[:,c]>0)][:,0],self.data[(self.label[:,c]>0)][:,1]) plt.xlim(-1,1) plt.ylim(-1,1) plt.axis('off') plt.show()
[docs] def show_decision_boundary(self, model): h = 0.001 x, y = np.meshgrid(np.arange(-1, 1, h), np.arange(-1, 1, h)) out = model(Tensor(np.c_[x.ravel(), y.ravel()])) pred = np.argmax(out.data, axis=1) if gpu: plt.contourf(to_cpu(x), to_cpu(y), to_cpu(pred.reshape(x.shape))) for c in range(self.num_class): plt.scatter(to_cpu(self.data[(self.label[:,c]>0)][:,0]),to_cpu(self.data[(self.label[:,c]>0)][:,1])) else: plt.contourf(x, y, pred.reshape(x.shape)) for c in range(self.num_class): plt.scatter(self.data[(self.label[:,c]>0)][:,0],self.data[(self.label[:,c]>0)][:,1]) plt.xlim(-1,1) plt.ylim(-1,1) plt.axis('off') plt.show()
[docs]class SwissRoll(Dataset): '''Swiss roll dataset\n Args: num_class (int): number of classes num_data (int): number of data for each classes Note: num_data % num_class == 0 ''' def __init__(self, num_class=5, num_data=2000): super().__init__() assert num_data % num_class == 0 self.num_class = num_class self.num_data = num_data self.data = np.zeros((self.num_data, 3)) theta = 2*np.pi*(1+2*np.random.rand(self.num_data,1)) x = theta*np.cos(theta) y = 21*np.random.rand(self.num_data,1) z = theta * np.sin(theta) self.data = np.concatenate((x,y,z), axis=1) self.data += 0.2*np.random.randn(self.num_data,3) self.label = np.zeros((self.num_data, self.num_class)) min = np.min(theta) i = (np.max(theta) - min)/self.num_class for c in range(self.num_class): self.label[:,c][np.logical_and((min+c*i<theta[:,0]),(theta[:,0]<(min+(c+1)*i)))] = 1
[docs] def show(self, label=None): fig = plt.figure() ax = fig.add_subplot(111, projection='3d') if gpu: for c in range(self.num_class): ax.scatter(to_cpu(self.data[(self.label[:,c] > 0)][:,0]), to_cpu(self.data[(self.label[:,c] > 0)][:,1]), to_cpu(self.data[(self.label[:,c] > 0)][:,2]),) else: for c in range(self.num_class): ax.scatter(self.data[(self.label[:,c] > 0)][:,0], self.data[(self.label[:,c] > 0)][:,1], self.data[(self.label[:,c] > 0)][:,2],) plt.show()