# Source code for elektronn2.tests.test_cnndata

from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, \
    super, zip

import time
import numpy as np
from matplotlib import pyplot as plt

from ..utils.plotting import scroll_plot
from ..neuromancer import Input, Input_like, Conv
from ..training.parallelisation import BackgroundProc


def test_cnndata():
    """Draw one batch from ``CNNData`` and dump it as PNG slices to ``/tmp``.

    Builds a small ``Input`` -> ``Conv`` -> ``Input_like`` graph, creates a
    ``CNNData`` source from three raw/barrier HDF5 file pairs, requests a
    single batch and writes each slice along the first spatial axis of the
    image and of the label to ``/tmp/<z>-img.png`` and
    ``/tmp/<z + off>-lab.png``.

    Two further sections (a background-process MALIS visualisation and a
    timing loop) are switched off with ``if False`` and kept for manual use.

    NOTE(review): ``CNNData`` and ``malis`` are not imported in this module's
    import block, so the ``CNNData(...)`` call below raises ``NameError`` as
    written. The ``__main__`` section of this file uses ``BatchCreatorImage``
    with different keyword arguments -- confirm which class is intended
    before re-enabling this test.
    """
    print("Testing CNNData")
    x = Input((1, 1, 5, 40, 40), 'b,f,z,x,y')
    y = Conv(x, 2, (5, 31, 31), (1, 1, 1))
    l = Input_like(y, dtype='int16')

    # Alternative single-cube data set, kept for reference:
    # data_path = os.path.expanduser('~/devel/data/BirdGT/')  # (*) Path to data dir
    # label_path = data_path
    # d_files = [('raw_cube0-crop.h5', 'raw')]
    # l_files = [('cube0_barrier-int16.h5', 'labels')]
    data_init_kwargs = dict(
        d_path='~/lustre/mkilling/BirdGT/j0126_new/',
        l_path='~/lustre/mkilling/BirdGT/j0126_new/',
        d_files=[('raw_%i.h5' % ii, 'raw') for ii in range(3)],
        l_files=[('barrier_int16_%i.h5' % ii, 'lab') for ii in range(3)],
        n_target=2,
        anisotropic_data=True,
        data_order='xyz',
        valid_cubes=[0, 1],
        nhood_targets=True)

    D = CNNData(x, l, **data_init_kwargs)
    kwargs = dict(batch_size=1, flip=True, grey_augment_channels=[0],
                  ret_ll_mask=True, warp_on=False, ignore_thresh=0.5)
    batch = D.getbatch(**kwargs)

    # ------------------------------------------------------------------------------
    if False:  # BG -- disabled: batches from a background process + MALIS plot
        bg = BackgroundProc(D.getbatch, n_proc=2, target_args=(),
                            target_kwargs=kwargs, profile=False)
        try:
            for i in range(1):
                data, label, seg, seg_gt, info1, info2 = bg.get()
                # data, label, info1, info2 = D.getbatch(**kwargs)
                # Drop the leading (batch) axis of each array.
                label = label[0]
                seg = seg[0]
                seg_gt = seg_gt[0]
                nhood = np.eye(3, dtype=np.int32)
                # Fake prediction: the labels blended with 5% uniform noise.
                aff_pred = label * 0.95 + 0.05 * np.random.uniform(
                    size=label.shape)
                # NOTE(review): ``malis`` is not imported in this module.
                pos_counts, neg_counts = malis.get_malis_weights(
                    aff_pred, label, seg, nhood)

                # Scale segment ids to [0, 1] so the colormap can be applied.
                seg_color = seg * 1.0 / seg.max()
                seg_color = plt.cm.nipy_spectral(seg_color)
                seg_color_gt = seg_gt * 1.0 / seg_gt.max()
                seg_color_gt = plt.cm.nipy_spectral(seg_color_gt)

                # Move the first axis to the end for plotting.
                aff_pred = np.transpose(aff_pred, (1, 2, 3, 0))
                pos_counts = np.transpose(pos_counts, (1, 2, 3, 0))
                neg_counts = np.log(neg_counts + 1)
                neg_counts = np.transpose(neg_counts, (1, 2, 3, 0))

                plt.ioff()
                # NOTE(review): ``scroll_plot`` is imported *from*
                # ``..utils.plotting``; the attribute access below only works
                # if that name is a module/namespace -- confirm.
                fig, scroller = scroll_plot.scroll_plot(
                    [aff_pred, seg_color, seg_color_gt, neg_counts],
                    ['PRED', 'SEG', 'SEG_GT', 'NEG'])
                fig.show()
                plt.show()
        except KeyboardInterrupt:
            pass
        finally:
            # Always stop the worker processes, even on Ctrl-C.
            bg.shutdown()

    if True:
        # Labels are written with index ``z + off`` so that their file names
        # are shifted by the first offset of the label node's shape.
        off = l.shape.offsets[0]
        data = batch[0]
        label = batch[1]

        img = data[0, 0,]
        for z in range(img.shape[0]):
            plt.imsave('/tmp/%i-img.png' % (z), img[z,], cmap='gray')

        lab = label[0, 0,]
        for z in range(lab.shape[0]):
            plt.imsave('/tmp/%i-lab.png' % (z + off),
                       lab[z].astype(np.uint8) * 255)

    # TIMING
    if False:
        # ``time.clock`` was removed in Python 3.8; prefer ``process_time``
        # and fall back to ``clock`` only on interpreters that lack it.
        cpu_clock = getattr(time, "process_time", None) or time.clock
        t0 = cpu_clock()
        tt0 = time.time()
        for i in range(100):
            data, label, ll_mask1, ll_mask2 = D.getbatch(**kwargs)
        t1 = cpu_clock()
        tt1 = time.time()
        print("clock {}".format(t1 - t0))
        print("time {}".format(tt1 - tt0))

        # NOTE(review): these arrays are only bound inside the disabled
        # ``BG`` section above, so this plot is kept inside a disabled
        # section as well; the original nesting could not be recovered.
        fig, scroller = scroll_plot.scroll_plot(
            [aff_pred, seg_color, pos_counts, neg_counts],
            ['PRED', 'LAB', 'POS', 'NEG'])
        fig.show()
if __name__ == "__main__":
    # Manual visual check of ``BatchCreatorImage`` with warping enabled:
    # shows one batch as an interactive scroll plot, then renders 100 more
    # batches to /tmp/<i>.png.
    from elektronn2.data.cnndata import BatchCreatorImage
    from elektronn2.data.image import center_cubes
    from elektronn2.utils.plotting import _scroll_plot2, scroll_plot

    x = Input((1, 1, 8, 71, 71), 'b,f,z,x,y')
    y = Conv(x, 2, (1, 5, 5), (1, 1, 1))
    l = Input_like(y, dtype='int16')

    data_init_kwargs = dict(
        d_path='~/lustre/mkilling/BirdGT/j0126_new/',
        l_path='~/lustre/mkilling/BirdGT/j0126_new/',
        d_files=[('raw_%i.h5' % ii, 'raw') for ii in range(3)],
        l_files=[('barrier_int16_%i.h5' % ii, 'lab') for ii in range(3)],
        aniso_factor=2.0,
        valid_cubes=[1, 2],
        target_discrete_ix=[])

    # NOTE(review): the target node passed here is ``y`` (the Conv output),
    # not the ``Input_like`` node ``l`` created above -- confirm intended.
    data = BatchCreatorImage(x, y, **data_init_kwargs)

    # Reverse the order of the three trailing axes of the first training
    # cube (axis 0 stays in place) for both data and labels.
    data.train_d[0] = np.transpose(data.train_d[0], (0, 3, 2, 1))
    data.train_l[0] = np.transpose(data.train_l[0], (0, 3, 2, 1))

    # Optional debugging aid: darken every 10th/20th plane so the warp is
    # visible as a grid pattern.
    # data.train_d[0][:,:,::20,:] *= 0.5
    # data.train_d[0][:,:,:,::20] *= 0.5
    # data.train_l[0][:,:,::20,:] *= 0.5
    # data.train_l[0][:,:,:,::20] *= 0.5
    # data.train_d[0][:,::10,:,:] *= 0.5
    # data.train_l[0][:,::10,:,:] *= 0.5

    sample_warp_params = dict(sample_aniso=True, lock_z=True,
                              no_x_flip=False, warp_amount=1,
                              perspective=True)
    data_batch_args = dict(grey_augment_channels=[0], ret_ll_mask=False,
                           warp=1, warp_args=sample_warp_params,
                           ignore_thresh=0.5, force_dense=True)

    # First batch: take sample 0 / channel 0 of image and target and plot
    # the raw image next to a 40/60 blend of target and image.
    batch = data.getbatch(**data_batch_args)
    img, target = batch[0][0, 0], batch[1][0, 0]
    img, target = center_cubes(img, target, crop=False)
    fig = _scroll_plot2([img, (0.4 * target + 0.6 * img)], ['raw', 'target'])

    # Turn interactive mode off so the figures created in the loop are not
    # drawn on screen while they are being saved.
    plt.ioff()
    for i in range(100):
        batch = data.getbatch(**data_batch_args)
        img, target = batch[0][0, 0], batch[1][0, 0]
        img, target = center_cubes(img, target, crop=False)
        fig = _scroll_plot2([img, (0.4 * target + 0.6 * img)],
                            ['raw', 'target'])
        plt.savefig("/tmp/%i.png" % i, bbox_inches='tight')

    # NOTE(review): placed after the loop (a blocking ``show`` per iteration
    # would stall the batch dump); the original nesting could not be
    # recovered -- confirm.
    plt.show()
    # fig2 = scroll_plot(data.train_l[0][0], 'l')