Source code for qualia2.nn.modules.module

# -*- coding: utf-8 -*- 
from ...core import *
from ...util import download_progress
from ...autograd import Tensor
from collections import OrderedDict 
from itertools import islice
import h5py as h5
import _pickle as pickle
import gzip
import os, sys
from logging import getLogger
# Module-level logger named 'QualiaLogger.module', i.e. a child of the
# package-wide 'QualiaLogger' logger.
logger = getLogger('QualiaLogger.module')

class Module(object):
    '''Base class for all neural network modules in qualia.

    Modules can incorporate other Modules, allowing them to be nested in a
    tree structure. Note that a user-defined model must call
    ``super().__init__()`` in its ``__init__`` before assigning any
    sub-module or parameter.

    Attribute routing (see ``__setattr__``): ``Module`` values are kept in
    ``self._modules``, ``Tensor`` values (or lists made up only of ``Tensor``
    objects) in ``self._params``, and everything else as an ordinary
    instance attribute.

    Examples::
        >>> # define a module
        >>> class Model(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        >>>         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        >>>         self.fc1 = nn.Linear(500, 50)
        >>>         self.fc2 = nn.Linear(50, 10)
        >>>     def forward(self, x):
        >>>         x = F.relu(F.maxpool2d(self.conv1(x), (2,2)))
        >>>         x = F.relu(F.maxpool2d(self.conv2(x), (2,2)))
        >>>         x = F.reshape(x,(-1, 500))
        >>>         x = F.relu(self.fc1(x))
        >>>         x = self.fc2(x)
        >>>         return x
    '''
    def __init__(self):
        self._modules = OrderedDict()
        self._params = OrderedDict()
        self.training = True
        self.num_params = 0
        self.input_shape = None
        self.output_shape = None

    def __repr__(self):
        result = '{}(\n'.format(self.__class__.__name__)
        for i, (name, module) in enumerate(self._modules.items()):
            result += ' [{}] {}: {}\n'.format(i, name, repr(module))
        if __debug__:
            result += ') at 0x{:0{}X}\n'.format(id(self), 16)
        else:
            result += ')\n'
        return result

    def __str__(self):
        return self.__class__.__name__

    def _module_summary(self):
        '''Logs one table row per leaf module and returns the summed ``num_params``.'''
        if not self._modules:
            logger.info('| {:20}|{:^20}|{:^20}|{:^10}|'.format(self.__class__.__name__, str(self.input_shape), str(self.output_shape), str(self.num_params)))
            return self.num_params
        return sum(module._module_summary() for module in self._modules.values())

    def summary(self, input_shape):
        '''Logs a layer-by-layer table of shapes and parameter counts.

        Args:
            input_shape (tuple): shape of a dummy zero input that is passed
                through ``forward`` so each layer can record its shapes.

        Raises:
            NotImplementedError: if ``input_shape`` is a list (multiple
                inputs are not supported).
        '''
        logger.info('-'*76)
        logger.info('{:^76}'.format('Model: ' + self.__class__.__name__))
        if isinstance(input_shape, list):
            raise NotImplementedError
        if isinstance(input_shape, tuple):
            x = Tensor(np.zeros(input_shape), requires_grad=False)
            if self._modules:
                logger.info('{}\n| {:20}|{:^20}|{:^20}|{:^10}|\n{}'.format('-'*76, 'layers', 'input shape', 'output shape', 'params #', '='*76))
                # Reset every leaf (not only direct children) so shapes left
                # over from earlier forward passes are not reported.
                for module in self.modules():
                    module.input_shape = None
                    module.output_shape = None
            self.forward(x)
        total_params = self._module_summary()
        logger.info('='*76)
        logger.info('total params: {}'.format(total_params))
        logger.info('training mode: {}'.format(self.training))
        logger.info('-'*76)

    def __setattr__(self, key, value):
        '''Stores ``value`` in ``_modules``, ``_params`` or the instance dict.

        ``Module`` -> ``_modules``; ``Tensor`` or a list made up only of
        ``Tensor`` objects -> ``_params``; anything else (including a list
        containing non-Tensor items) -> ordinary instance attribute.

        Raises:
            AttributeError: if a sub-module or parameter is assigned before
                ``Module.__init__`` has created the stores.
        '''
        # Read the stores from __dict__ directly so this never goes through
        # __getattr__ (important while __init__ itself is running).
        attrs = self.__dict__
        if isinstance(value, Module):
            target = '_modules'
        elif isinstance(value, Tensor):
            target = '_params'
        elif type(value) is list and all(isinstance(n, Tensor) for n in value):
            target = '_params'
        else:
            target = None
        # Forget an earlier binding of ``key`` kept in a different store;
        # otherwise a stale sub-module/parameter would still be yielded by
        # params()/state_dict(), or a stale instance attribute would shadow
        # the new value (instance-dict lookup wins over __getattr__).
        for store_name in ('_modules', '_params'):
            store = attrs.get(store_name)
            if store_name != target and store is not None:
                store.pop(key, None)
        if target is None:
            object.__setattr__(self, key, value)
            return
        store = attrs.get(target)
        if store is None:
            raise AttributeError(
                'cannot assign {!r} before Module.__init__() is called; '
                'call super().__init__() first'.format(key))
        attrs.pop(key, None)
        store[key] = value

    def __getattr__(self, key):
        '''Resolves sub-modules and parameters (called only when normal lookup fails).

        Raises:
            AttributeError: if ``key`` is neither a sub-module nor a parameter,
                so ``hasattr``/``getattr(obj, name, default)``/``copy`` work.
        '''
        # Use __dict__ rather than self._modules: if __init__ has not run
        # (e.g. instance created via __new__ while unpickling) the latter
        # would re-enter __getattr__ and recurse forever.
        attrs = self.__dict__
        for store_name in ('_modules', '_params'):
            store = attrs.get(store_name)
            if store is not None and key in store:
                return store[key]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, key))

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def add_module(self, name, module):
        '''Registers ``module`` as a sub-module under ``name``.

        Raises:
            TypeError: if ``module`` is not a ``Module``.
        '''
        # An explicit check instead of ``assert``: asserts vanish under -O.
        if not isinstance(module, Module):
            raise TypeError('{} is not a Module subclass'.format(type(module).__name__))
        self._modules[name] = module

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def apply(self, fn):
        '''Calls ``fn`` on every leaf module (on ``self`` if it has no sub-modules).'''
        if not self._modules:
            fn(self)
            return
        for module in self._modules.values():
            module.apply(fn)

    def modules(self):
        '''Yields every leaf module (``self`` if it has no sub-modules).'''
        if not self._modules:
            yield self
            return
        for module in self._modules.values():
            yield from module.modules()

    def params(self):
        '''Yields sub-module parameters first, then this module's own, flattening lists.'''
        for module in self._modules.values():
            yield from module.params()
        for var in self._params.values():
            if type(var) is list:
                yield from var
            else:
                yield var

    def zero_grad(self):
        '''Sets ``grad`` to ``None`` on every parameter reachable from this module.'''
        for var in self.params():
            var.grad = None

    def eval(self):
        '''Recursively sets ``training = False``.'''
        for module in self._modules.values():
            module.eval()
        self.training = False

    def train(self):
        '''Recursively sets ``training = True``.'''
        for module in self._modules.values():
            module.train()
        self.training = True

    def _create_state_dict(self, name='', dtype='float64'):
        '''Builds a flat ``{dotted_name: array}`` dict; list entries get an index suffix.'''
        state_dict = {}
        for key, module in self._modules.items():
            state_dict.update(module._create_state_dict(name+str(key)+'.', dtype))
        for key, value in self._params.items():
            if type(value) is list:
                for i, val in enumerate(value):
                    state_dict[name+str(key)+'.'+str(i)] = val.data.astype(dtype)
            else:
                state_dict[name+str(key)] = value.data.astype(dtype)
        return state_dict

    def state_dict(self, dtype='float64'):
        '''Returns a dictionary containing a whole state of the module.

        Args:
            dtype (str): dtype every array is cast to.
        '''
        return self._create_state_dict('', dtype)

    def load_state_dict(self, state_dict, name=''):
        '''Copies parameters from the state_dict into this module.

        Each array is cast to the dtype of the parameter it replaces.

        Raises:
            KeyError: if an expected entry is missing from ``state_dict``.
        '''
        for key, module in self._modules.items():
            module.load_state_dict(state_dict, name+str(key)+'.')
        for key, value in self._params.items():
            # NOTE(review): ``np`` comes from ``...core`` via star-import and
            # may not be NumPy itself; np.copy is kept (not just astype) in
            # case it also converts host arrays to that backend — confirm.
            if type(value) is list:
                for i, val in enumerate(value):
                    val.data = np.copy(state_dict[name+str(key)+'.'+str(i)].astype(val.dtype))
            else:
                value.data = np.copy(state_dict[name+str(key)].astype(value.dtype))

    def load_state_dict_from_url(self, url, version=1):
        '''Downloads and copies parameters from the state_dict at the url into this module.

        The file is cached under ``<home_dir>/pretrained`` and reused on later
        calls. It is downloaded to a temporary ``.part`` file and only renamed
        into place once complete, so an interrupted download is not cached.

        Args:
            url (str): location of a file written by ``save``.
            version (int): format of the file (see ``load``).

        Raises:
            ValueError: if the url path has no file name.
        '''
        from urllib.parse import urlparse
        from urllib.request import urlretrieve
        model_dir = os.path.join(home_dir, 'pretrained')
        os.makedirs(model_dir, exist_ok=True)
        filename = os.path.basename(urlparse(url).path)
        if not filename:
            raise ValueError('cannot derive a file name from url: {}'.format(url))
        cache = os.path.join(model_dir, filename)
        if not os.path.exists(cache):
            partial = cache + '.part'
            try:
                urlretrieve(url, partial, reporthook=download_progress)
                os.replace(partial, cache)
            finally:
                # Only present if the download or rename failed.
                if os.path.exists(partial):
                    os.remove(partial)
            print('\n')
        self.load(cache, version)

    def __save__(self, h5file):
        '''Writes parameters into an HDF5 group (one sub-group per sub-module), always as float64.'''
        for name, module in self._modules.items():
            module.__save__(h5file.create_group(str(name)))
        for key, value in self._params.items():
            if type(value) is list:
                grp = h5file.create_group(str(key))
                for i, val in enumerate(value):
                    grp.create_dataset(str(i), dtype='f8', data=val.asnumpy())
            else:
                h5file.create_dataset(str(key), dtype='f8', data=value.asnumpy())

    def __load__(self, h5file):
        '''Reads parameters written by ``__save__``, cast to each parameter's dtype.'''
        for name, module in self._modules.items():
            module.__load__(h5file[str(name)])
        for key, value in self._params.items():
            # Cast like load_state_dict does, so loading float64 data does not
            # silently change the dtype of e.g. float32 parameters.
            if type(value) is list:
                for i, val in enumerate(value):
                    val.data = np.array(h5file[str(key)][str(i)]).astype(val.dtype)
            else:
                value.data = np.array(h5file[str(key)]).astype(value.dtype)

    def save(self, filename, dtype='float16', protocol=-1, version=0):
        '''Saves internal parameters of the Module.

        Args:
            filename (str): specify the filename as well as the saving path without the file extension. (ex) path/to/filename
            dtype (str): data type to save (version 1 only; version 0 always stores float64)
            protocol (int): pickle protocol (version 1 only)
            version (int): 0 writes ``filename.hdf5``; 1 writes a gzipped pickle ``filename.qla``, which takes less disk space.

        Raises:
            ValueError: if ``version`` is not 0 or 1.
        '''
        if version == 1:
            with gzip.open(filename+'.qla', 'wb') as f:
                pickle.dump(self.state_dict(dtype), f, protocol)
        elif version == 0:
            with h5.File(filename+'.hdf5', 'w') as file:
                self.__save__(file)
        else:
            raise ValueError('unsupported save version: {!r} (expected 0 or 1)'.format(version))

    def load(self, filename, version=0):
        '''Loads parameters written by ``save`` into the Module.

        Args:
            filename (str): specify the filename as well as the path to the file with the file extension. (ex) path/to/filename.qla
            version (int): 0 for an HDF5 file, 1 for a gzipped pickle (``.qla``).

        Raises:
            ValueError: if ``version`` is not 0 or 1.
        '''
        if version == 1:
            # SECURITY: unpickling can execute arbitrary code; only load
            # .qla files from trusted sources.
            with gzip.open(filename, 'rb') as f:
                self.load_state_dict(pickle.load(f))
        elif version == 0:
            with h5.File(filename, 'r') as file:
                self.__load__(file)
        else:
            raise ValueError('unsupported load version: {!r} (expected 0 or 1)'.format(version))
class Sequential(Module):
    r'''A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Positional modules are keyed by their argument position
    (``'0'``, ``'1'``, ...) and keyword modules by their keyword. Arguments
    that are not ``Module`` instances are ignored.

    Examples::
        >>> # model can be defined by adding Modules
        >>> model = Sequential(
        >>>     nn.Conv2d(1,20,5),
        >>>     nn.ReLU(),
        >>>     nn.Conv2d(20,64,5),
        >>>     nn.ReLU()
        >>>     )
        >>> # name for each layers can also be specified
        >>> model = Sequential(
        >>>     conv1 = nn.Conv2d(1,20,5),
        >>>     relu1 = nn.ReLU(),
        >>>     conv2 = nn.Conv2d(20,64,5),
        >>>     relu2 = nn.ReLU()
        >>>     )
    '''
    def __init__(self, *args, **kwargs):
        super().__init__()
        for i, module in enumerate(args):
            if isinstance(module, Module):
                self._modules[str(i)] = module
        for name, module in kwargs.items():
            if isinstance(module, Module):
                self._modules[name] = module

    def __getitem__(self, index):
        '''Indexes the contained modules.

        Args:
            index (slice | int | str): a slice returns a new ``Sequential``
                (renumbered from ``'0'``) holding the selected modules; an
                int returns the module at that position (negative allowed);
                a str returns the module registered under that name.

        Raises:
            IndexError: for an out-of-range int.
            KeyError: for an unknown name.
        '''
        if isinstance(index, str):
            return self._modules[index]
        modules = list(self._modules.values())
        if isinstance(index, slice):
            return Sequential(*modules[index])
        return modules[index]

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        '''Feeds ``x`` through every contained module in insertion order.'''
        for module in self._modules.values():
            x = module.forward(x)
        return x

    def append(self, *arg, **kwarg):
        '''Appends modules; positional ones get the next free numeric key.

        All arguments are validated before anything is added, so a bad
        argument leaves the container unchanged.

        Raises:
            Exception: if an argument is not a ``Module``, or if more than one
                positional and more than one keyword argument are given.
        '''
        if len(arg) > 1 and len(kwarg) > 1:
            raise Exception('Too much arguments were given.')
        if not all(isinstance(module, Module) for module in (*arg, *kwarg.values())):
            raise Exception('Invalid argument was given. Failed to append.')
        for module in arg:
            # len() alone can collide with an existing key when the constructor
            # skipped non-Module arguments (keys '0', '2' -> len 2 -> '2').
            idx = len(self._modules)
            while str(idx) in self._modules:
                idx += 1
            self._modules[str(idx)] = module
        for name, module in kwarg.items():
            self._modules[name] = module