# Source code for elektronn2.data.traindata

# -*- coding: utf-8 -*-
"""
Copyright (c) 2015 Marius Killinger, Sven Dorkenwald, Philipp Schubert
All rights reserved
"""
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip

__all__ = ['Data', 'MNISTData', 'PianoData', 'PianoData_perc']

import logging
import os
import time
try:
    import urllib.request as urllib2
except ImportError:
    import urllib2

import numpy as np


logger = logging.getLogger('elektronn2log')

try:
    from sklearn import cross_validation
except:
    logger.warning("Can't import cross_validation from sklearn. Make sure to "
    "install it if you want to use it.")

from ..utils import pickleload

class Data(object):
    """
    Load and prepare data; base object for dataset classes.

    Subclasses are expected to assign ``self.train_d`` / ``self.train_l``
    (and the ``valid_*`` / ``test_*`` counterparts) *before* calling
    ``Data.__init__``.
    """

    def __init__(self, n_lab=None):
        # Cursor into the current epoch's permutation of the training set.
        self._pos = 0
        # Provided by the subclass before this constructor runs:
        # self.train_d, self.train_l, self.valid_d, self.valid_l,
        # self.test_d, self.test_l
        if isinstance(self.train_d, np.ndarray):
            self._training_count = self.train_d.shape[0]
            if n_lab is None:
                self.n_lab = np.unique(self.train_l).size
            else:
                self.n_lab = n_lab
        elif isinstance(self.train_d, list):
            self._training_count = len(self.train_d)
            if n_lab is None:
                unique = [np.unique(l) for l in self.train_l]
                self.n_lab = np.unique(np.hstack(unique)).size
            else:
                self.n_lab = n_lab

        # getattr-default so subclasses that do not pre-set ``example_shape``
        # (e.g. sequence datasets) do not raise AttributeError here.
        if getattr(self, 'example_shape', None) is None:
            self.example_shape = self.train_d[0].shape
        self.n_ch = self.example_shape[0]

        # Seed from the fractional part of the current time so separate runs
        # get different permutations.
        self.rng = np.random.RandomState(
            np.uint32((time.time() * 0.0001
                       - int(time.time() * 0.0001)) * 4294967295))
        self.pid = os.getpid()
        logger.info(self.__repr__())
        self._perm = self.rng.permutation(self._training_count)

    def _reseed(self):
        """Reseeds the rng if the process ID has changed!"""
        current_pid = os.getpid()
        if current_pid != self.pid:
            self.pid = current_pid
            self.rng.seed(
                np.uint32((time.time() * 0.0001
                           - int(time.time() * 0.0001))
                          * 4294967295 + self.pid))
            logger.debug("Reseeding RNG in Process with PID: {}".format(self.pid))

    def __repr__(self):
        return "%i-class Data Set: #training examples: %i and #validing: %i" \
               % (self.n_lab, self._training_count, len(self.valid_d))

    def getbatch(self, batch_size, source='train'):
        """
        Return a ``(data, label)`` batch.

        For ``source='train'`` the batch is drawn from a per-epoch random
        permutation; for ``'valid'``/``'test'`` the first ``batch_size``
        examples are returned.
        """
        if source == 'train':
            if (self._pos + batch_size) < self._training_count:
                self._pos += batch_size
                # renamed from ``slice`` which shadowed the builtin
                sel = self._perm[self._pos - batch_size:self._pos]
            else:  # epoch exhausted: draw a new permutation
                self._perm = self.rng.permutation(self._training_count)
                self._pos = 0
                sel = self._perm[:batch_size]
            if isinstance(self.train_d, np.ndarray):
                return (self.train_d[sel], self.train_l[sel])
            elif isinstance(self.train_d, list):
                data = np.array([self.train_d[i] for i in sel])
                label = np.array([self.train_l[i] for i in sel])
                return (data, label)
        elif source == 'valid':
            data = self.valid_d[:batch_size]
            label = self.valid_l[:batch_size]
            return (data, label)
        elif source == 'test':
            data = self.test_d[:batch_size]
            label = self.test_l[:batch_size]
            return (data, label)

    def createCVSplit(self, data, label, n_folds=3, use_fold=2, shuffle=False,
                      random_state=None):
        """
        Split ``data``/``label`` into train and validation parts using K-fold
        cross-validation and store them on the instance.

        ``use_fold`` selects which fold becomes the validation set.

        NOTE: previously delegated to ``sklearn.cross_validation.KFold``,
        which was removed in scikit-learn >= 0.20. This NumPy implementation
        reproduces its fold layout: contiguous folds (first ``n % n_folds``
        folds get one extra element), optionally over shuffled indices, with
        both index arrays returned in sorted order.
        """
        n = len(data)
        indices = np.arange(n)
        if shuffle:
            np.random.RandomState(random_state).shuffle(indices)
        fold_sizes = np.full(n_folds, n // n_folds, dtype=int)
        fold_sizes[:n % n_folds] += 1
        start = 0
        for fold, size in enumerate(fold_sizes):
            stop = start + size
            if fold == use_fold:
                # sorted, matching the old sklearn mask-based iteration
                valid_i = np.sort(indices[start:stop])
                train_i = np.sort(np.concatenate((indices[:start],
                                                  indices[stop:])))
                self.valid_d = data[valid_i]
                self.valid_l = label[valid_i]
                self.train_d = data[train_i]
                self.train_l = label[train_i]
            start = stop
##########################################################################################


def _augmentMNIST(data, label, crop=2, factor=4):
    """
    Creates new data, by cropping/shifting data.
    Control blow-up by factor and maximum offset by crop
    """
    side = data.shape[-1]
    out_side = side - crop
    # Seed the outputs with empty arrays so the concatenations below keep the
    # original dtype-promotion behaviour (float32 data, int16 labels).
    aug_data = np.zeros((0, 1, out_side, out_side), dtype=np.float32)
    aug_label = np.zeros((0,), dtype=np.int16)
    # All (x, y) offsets reachable within the crop margin.
    offsets = [(k % crop, int(k / crop) % crop) for k in range(crop ** 2)]
    order = np.random.permutation(range(crop ** 2))
    for step in range(factor):  # create <factor> shifted copies of the data
        dx, dy = offsets[order[step]]
        window = data[:, :, dx:dx + out_side, dy:dy + out_side]
        aug_data = np.concatenate((aug_data, window), axis=0)
        aug_label = np.concatenate((aug_label, label), axis=0)
    return aug_data, aug_label
class MNISTData(Data):
    """
    MNIST dataset: loads the pickled splits (downloading if no path is
    given), optionally converts flat vectors to ``(1, 28, 28)`` images,
    centers them, and applies shift augmentation.
    """

    def __init__(self, input_node, target_node, path=None, convert2image=True,
                 warp_on=False, shift_augment=True, center=True):
        if path is None:
            (self.train_d, self.train_l), (self.valid_d, self.valid_l), \
                (self.test_d, self.test_l) = self.download()
        else:
            path = os.path.expanduser(path)
            (self.train_d, self.train_l), (self.valid_d, self.valid_l), \
                (self.test_d, self.test_l) = pickleload(path)

        self.warp_on = warp_on
        # NOTE: attribute keeps the historical typo ("shif_") for backward
        # compatibility with existing callers.
        self.shif_augment = shift_augment
        self.return_flat = not convert2image

        self.test_l = self.test_l.astype(np.int16)
        self.train_l = self.train_l.astype(np.int16)
        self.valid_l = self.valid_l.astype(np.int16)
        self.example_shape = None
        if center:
            # Zero-center each split by its own mean.
            self.test_d -= self.test_d.mean()
            self.train_d -= self.train_d.mean()
            self.valid_d -= self.valid_d.mean()

        self.convert_to_image()
        if self.shif_augment:
            # valid/test are border-stripped by 1 px; train reaches the same
            # size via the crop inside _augmentMNIST (crop=2).
            self._stripborder(1)
            self.train_d, self.train_l = _augmentMNIST(
                self.train_d, self.train_l, crop=2, factor=4)

        # Labels get a trailing singleton axis: (N,) -> (N, 1).
        self.train_l = self.train_l[:, None]
        self.test_l = self.test_l[:, None]
        self.valid_l = self.valid_l[:, None]

        super(MNISTData, self).__init__()
        if not convert2image:
            self.example_shape = self.train_d[0].size

        logger.info("MNIST data is converted/augmented to shape {}".format(
            self.example_shape))

    @staticmethod
    def download():
        """
        Download ``mnist.pkl.gz`` to a per-user cache directory (unless
        already present) and return the unpickled splits.
        """
        if os.name == 'nt':
            dest = os.path.join(os.environ['APPDATA'], 'ELEKTRONN')
        else:
            dest = os.path.join(os.path.expanduser('~'), '.ELEKTRONN')
        if not os.path.exists(dest):
            os.makedirs(dest)
        dest = os.path.join(dest, 'mnist.pkl.gz')
        if os.path.exists(dest):
            print("Found existing mnist data")
            return pickleload(dest)
        else:
            # BUGFIX: added the missing space; the two fragments previously
            # concatenated to "...fromhttp://...".
            print("Downloading mnist data from "
                  "http://www.elektronn.org/downloads/mnist.pkl.gz")
            f = urllib2.urlopen(
                "http://www.elektronn.org/downloads/mnist.pkl.gz")
            try:
                data = f.read()
            finally:
                # BUGFIX: the response object was never closed.
                f.close()
            print("Saving data to %s" % (dest,))
            with open(dest, "wb") as code:
                code.write(data)
            return pickleload(dest)

    def convert_to_image(self):
        """For MNIST / flattened 2d, single-Layer, square images"""
        valid_size = self.valid_l.size
        test_size = self.test_l.size
        # Stack all splits, reshape once, then slice the splits back out.
        data = np.vstack((self.valid_d, self.test_d, self.train_d))
        size = data[0].size
        n = int(np.sqrt(size))
        assert abs(n ** 2 - size) < 1e-6, '<convertToImage> data is not square'
        count = data.shape[0]
        data = data.reshape((count, 1, n, n))
        self.valid_d = data[:valid_size]
        self.test_d = data[valid_size:valid_size + test_size]
        self.train_d = data[valid_size + test_size:]

    def getbatch(self, batch_size, source='train'):
        """
        Return a batch like :meth:`Data.getbatch`, optionally warping train
        batches and flattening images when ``return_flat`` is set.
        """
        if source == 'valid':
            ret = super(MNISTData, self).getbatch(batch_size, 'valid')
        # BUGFIX: this was a second plain ``if``, so 'valid' requests fell
        # into the else-branch, re-fetched the batch and warp-augmented
        # validation data.
        elif source == 'test':
            ret = super(MNISTData, self).getbatch(batch_size, 'test')
        else:
            d, l = super(MNISTData, self).getbatch(batch_size, source)
            if self.warp_on:
                d = self._warpaugment(d)
            ret = d, l
        if self.return_flat:
            ret = (ret[0].reshape((batch_size, -1)), ret[1])
        return ret

    def _stripborder(self, pix=1):
        # Crop ``pix`` pixels from every border of valid/test images only;
        # train images are size-reduced by _augmentMNIST instead.
        s = self.train_d.shape[-1]
        self.valid_d = self.valid_d[:, :, pix:s - pix, pix:s - pix]
        self.test_d = self.test_d[:, :, pix:s - pix, pix:s - pix]

    def _warpaugment(self, d, amount=1):
        # Warping is currently suspended; the random parameters are drawn
        # but unused until reimplemented with warp.py.
        rot_max = 5 * amount
        shear_max = 7 * amount
        scale_max = 1.15 * amount
        stretch_max = 0.25 * amount
        shear = shear_max * 2 * (np.random.rand() - 0.5)
        twist = rot_max * 2 * (np.random.rand() - 0.5)
        rot = 0  # min(rot_max - abs(twist), rot_max * (np.random.rand()))
        scale = 1 + (scale_max - 1) * np.random.rand(2)
        stretch = stretch_max * 2 * (np.random.rand(4) - 0.5)
        ps = (d.shape[0],) + d.shape[2:]
        raise ValueError("Warping is suspended, reimplement using warp.py")
        # BUGFIX: removed the unreachable ``return w`` that referenced an
        # undefined name. Previously:
        # w = warping.warp3dFast(d, ps, rot, shear,
        #                        (scale[0], scale[1], 1), stretch, twist)
        # return w
class PianoData(Data):
    """
    Piano-roll sequence dataset: each example is a sequence; batches are
    random windows of ``n_tap`` consecutive time steps (input) and the
    following step (target).
    """

    def __init__(self, input_node, target_node,
                 path='/home/mkilling/devel/data/PianoRoll/Nottingham_enc.pkl',
                 n_tap=20, n_lab=58):
        path = os.path.expanduser(path)
        (self.train_d, self.valid_d, self.test_d) = pickleload(path)
        # BUGFIX: Data.__init__ reads self.example_shape; it was only
        # assigned afterwards, raising AttributeError. Pre-set it here.
        self.example_shape = None
        super(PianoData, self).__init__(n_lab=n_lab)
        self.example_shape = self.train_d[0].shape[-1]
        self.n_taps = n_tap
        self.n_lab = n_lab

    def getbatch(self, batch_size, source='train'):
        """
        Return ``(x, y)`` where ``x`` holds ``n_taps`` consecutive steps of
        each selected sequence (float32) and ``y`` the step after the window.
        Window start positions are drawn uniformly at random per sequence.
        """
        if source == 'train':
            if (self._pos + batch_size) < self._training_count:
                self._pos += batch_size
                sel = self._perm[self._pos - batch_size:self._pos]
            else:  # epoch exhausted: draw a new permutation
                self._perm = self.rng.permutation(self._training_count)
                self._pos = 0
                sel = self._perm[:batch_size]
            data = [self.train_d[i] for i in sel]
        elif source == 'valid':
            data = self.valid_d[:batch_size]
        elif source == 'test':
            data = self.test_d[:batch_size]
        # BUGFIX: np.array(map(len, data)) wraps the lazy iterator (this file
        # imports builtins.map) in a 0-d object array, breaking the
        # arithmetic below; materialize the lengths instead.
        lengths = np.array([len(seq) for seq in data])
        # BUGFIX: the np.int alias was removed in NumPy >= 1.24; the builtin
        # int is the documented replacement.
        start_t = np.round(np.random.rand(batch_size)
                           * (lengths - self.n_taps - 1)).astype(int)
        x = np.array([seq[t:t + self.n_taps].astype(np.float32)
                      for seq, t in zip(data, start_t)])
        y = np.array([seq[t + self.n_taps] for seq, t in zip(data, start_t)])
        return x, y
class PianoData_perc(PianoData):
    """
    PianoData variant that returns the input window flattened with the tap
    axis leading (perceptron-style input layout).
    """

    def __init__(self, input_node, target_node,
                 path='/home/mkilling/devel/data/PianoRoll/Nottingham_enc.pkl',
                 n_tap=20, n_lab=58):
        # BUGFIX: forward the caller's arguments. Previously hard-coded
        # values were passed to super().__init__, silently ignoring any
        # path/n_tap/n_lab the caller supplied. Defaults are unchanged, so
        # existing callers see identical behavior.
        super(PianoData_perc, self).__init__(input_node, target_node,
                                             path=path, n_tap=n_tap,
                                             n_lab=n_lab)

    def getbatch(self, batch_size, source='train'):
        """
        Return ``(x, y)`` with ``x`` reshaped so the tap axis leads and the
        remaining axes are flattened per tap.
        """
        x, y = super(PianoData_perc, self).getbatch(batch_size, source)
        # (batch, taps, ...) -> (taps, batch, ...) — presumably so each tap
        # row groups the whole batch; confirm against the consuming model.
        x = x.swapaxes(0, 1)
        return x.reshape((x.shape[0], -1)), y