# Source code for elektronn2.training.trainer

# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Marius F. Killinger
# All rights reserved

from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip


import logging
import os
import sys
import time
import getpass
import traceback
from collections import OrderedDict

import numpy as np
import scipy.misc as misc
import matplotlib.pyplot as plt

import theano

from ..neuromancer.model import modelload, rebuild_model
from ..neuromancer.loss import SquaredLoss
from .. import utils
from ..config import config
from ..utils.plotting import plot_trainingtarget, my_quiver
from ..data import transformations
from . import trainutils
from .parallelisation import BackgroundProc
from .trainutils import HistoryTracker, Schedule

floatX = theano.config.floatX

logger = logging.getLogger('elektronn2log')
inspection_logger = logging.getLogger('elektronn2log-inspection')
user_name = getpass.getuser()
__all__ = ['Trainer', 'TracingTrainer', 'TracingTrainerRNN']


# Message logged when training diverges to a non-finite (nan/inf) loss;
# run() logs this and then drops the user into the interactive debug shell.
# Parenthesized implicit string concatenation replaces the original
# backslash continuations, which ended with a stray trailing backslash
# splicing onto the following blank line.
NLL_TEXT = ("The NN diverged to `nan` Loss!!!\n "
            "You have the chance to inspect the last used examples and the "
            "internal state of the pipeline in the command line. The last "
            "presented training input data is `batch[0]` and the "
            "corresponding target `batch[1]`")
class Trainer(object):
    """
    Coordinates an experiment: builds/loads the model, loads the data class,
    tracks history, and (in ``run``) executes the training loop.

    Parameters
    ----------
    exp_config : object
        Experiment configuration; must expose the attributes accessed below
        (``create_model``, ``model_load_path``, ``data_class``,
        ``data_batch_args``, ``monitor_batch_size``, ``save_path``, ...).
    """

    def __init__(self, exp_config):
        self.exp_config = exp_config
        self.schedules = OrderedDict()  # filled by _create_model from exp_config.schedules
        self.model = self._create_model()
        self.data = self._load_data()
        self.batch_size = self._infer_batch_size()
        self.tracker = HistoryTracker()
        self.tracker.register_debug_output_names(self.model.debug_output_names)
        self.preview_data = self._load_preview_data()
        self.saved_raw_preview = False  # raw preview images are written only once
        # getbatch kwargs are shared state: mutated here and copied in test_model()
        self.get_batch_kwargs = self.exp_config.data_batch_args
        self.get_batch_kwargs['batch_size'] = self.batch_size
        self.get_batch_kwargs['source'] = 'train'
        if self.exp_config.monitor_batch_size%self.batch_size!=0:
            # NOTE(review): this assigns the number of sub-batches, not a
            # rounded-up multiple of batch_size — looks like a missing
            # ``* self.batch_size``; confirm against test_model()'s chunking.
            bs = int(np.ceil(float(self.exp_config.monitor_batch_size)/self.batch_size))
            self.exp_config.monitor_batch_size = bs
        os.chdir(self.exp_config.save_path) # The trainer works directly in the save dir
        #self.debug_shit = []

    def _create_model(self):
        """Create a fresh model via ``exp_config.create_model`` or load one
        from ``exp_config.model_load_path``, then bind parameter schedules."""
        if self.exp_config.create_model:
            if self.exp_config.model_load_args:
                mdl = self.exp_config.create_model(
                    self.exp_config.model_load_args)
            else:
                mdl = self.exp_config.create_model()
        else:
            mdl = modelload(self.exp_config.model_load_path,
                            **self.exp_config.model_load_args)
        mdl.set_opt_meta_params(self.exp_config.optimiser,
                                self.exp_config.optimiser_params)
        # Bind each configured schedule to the model attribute of the same
        # name, falling back to the experiment config if the model lacks it.
        for var, params in self.exp_config.schedules.items():
            if params:
                schedule = Schedule(**params)
                try:
                    schedule.bind_variable(obj=mdl, prop_name=var)
                except:  # broad on purpose: any bind failure falls back to config
                    logger.debug("%s not found in model, trying config now" %(var))
                    schedule.bind_variable(obj=self.exp_config, prop_name=var)
                self.schedules[var] = schedule
                logger.info(schedule)
        return mdl

    def _load_data(self):
        """Instantiate the data class named (or referenced by file) in
        ``exp_config.data_class`` with the model's input/target nodes."""
        if isinstance(self.exp_config.data_class, (list, tuple)):
            # (module_path, class_name) pair: import from an arbitrary file
            mod, cls = self.exp_config.data_class
            cls = utils.import_variable_from_file(mod, cls)
        else:
            from .. import data
            cls = getattr(data, self.exp_config.data_class)
        return cls(self.model.input_node, self.model.target_node,
                   **self.exp_config.data_init_kwargs)

    def _load_preview_data(self):
        """Load preview raw cubes from HDF5 (or return None if unconfigured);
        each cube is cast to floatX and normalised by its max."""
        if self.exp_config.preview_data_path is not None:
            data = utils.h5load(self.exp_config.preview_data_path)
            # NOTE(review): the second isinstance already covers list, so the
            # ``or`` is redundant — kept byte-identical here.
            if not (isinstance(data, list) or isinstance(data, (tuple, list))):
                data = [data,]
            data = [d.astype(floatX)/d.max() for d in data]
            return data
        else:
            return None

    def _infer_batch_size(self):
        """Resolve the batch size from model vs. experiment config, asserting
        consistency when both are set.

        Returns
        -------
        int or None
            NOTE(review): implicitly returns None when neither source sets a
            batch size — confirm callers tolerate that.
        """
        model_bs = self.model.batch_size
        conf_bs = self.exp_config.batch_size
        if model_bs:
            if conf_bs:
                assert model_bs==conf_bs, "Conflicting batchsizes from " \
                                          "model (%d) and experiment " \
                                          "configuration (%d)" % (model_bs, conf_bs)
            return model_bs
        elif conf_bs:
            return conf_bs
    def run(self):
        """
        Execute the main training loop until ``exp_config.n_steps`` is
        reached, the runtime limit expires, or the user terminates.

        Per iteration: fetch a batch (optionally from background workers),
        take one optimiser step, track loss/time, periodically save
        parameters, write preview predictions, apply schedules, and log
        train/valid errors. KeyboardInterrupt/ValueError/TypeError drop into
        an interactive inspection shell instead of aborting.
        """
        exp_config = self.exp_config
        save_name = exp_config.save_name
        data = self.data
        t_passed = 0
        t_pt = 2
        t_pi = 2  # EMA of seconds per iteration, seeded at 2
        last_save_t = 0
        last_save_t2= 0
        save_time = config.param_save_h
        save_time2 = config.initial_prev_h
        loss, loss_smooth, train_loss, valid_loss, train_error, valid_error, param_vars = 0, 0, 0, 0, 0, 0, 0
        user_termination = False
        # Regression (SquaredLoss) reports raw errors; otherwise percentages.
        if isinstance(self.model.loss_node.parent[0], SquaredLoss):
            is_regression = True
        else:
            is_regression = False
        pp_err = 'err' if is_regression else '%'
        # --------------------------------------------------------------------
        if config.background_processes:
            n_proc = max(2, int(config.background_processes))
            bg_worker = BackgroundProc(data.getbatch, n_proc=n_proc,
                                       target_kwargs=self.get_batch_kwargs)
        # --------------------------------------------------------------------
        try:
            i = -1
            t0 = time.time()
            while i < exp_config.n_steps:
                try:
                    if config.background_processes:
                        batch = bg_worker.get()
                    else:
                        batch = data.getbatch(**self.get_batch_kwargs)
                    if exp_config.class_weights is not None:
                        batch = batch + (exp_config.class_weights,)
                    #self.debug_shit.append(batch[1])
                    #---------------------------------------------------------
                    loss, t_per_train, debug_outputs = self.model.trainingstep(*batch, optimiser=exp_config.optimiser) # Update step
                    i += 1
                    #---------------------------------------------------------
                    t_per_it = time.time() - t0
                    t_passed += t_per_it
                    t0 = time.time()
                    t_pi = 0.8*t_pi + 0.2*t_per_it # EMA
                    loss_smooth = self.model.loss_smooth
                    # check for divergence: non-finite loss aborts into the UI
                    if np.any(np.isnan(loss)) or np.any(np.isinf(loss)):
                        logger.warning(NLL_TEXT)
                        raise KeyboardInterrupt
                        #self.self.model.optimisers[exp_config.optimiser].repair_fuckup()
                    if len(batch) == 1:
                        batch_char = 0
                    else:
                        batch_char = batch[1][:,0].mean() # assuming targets have shape (b,f,...)
                    self.tracker.update_timeline([t_passed, loss, batch_char])
                    if debug_outputs:
                        self.tracker.update_debug_outputs([i, loss,]+debug_outputs)
                    # Save Parameters
                    # if (t_passed-last_save_t)/3600 > config.param_save_h:
                    #     last_save_t = t_passed
                    #     time_string = '-'+str(save_time)+'h'
                    #     self.model.save(os.path.join('Backup', save_name+time_string+'.mdl'))
                    #     save_time += config.param_save_h
                    if i%config.param_save_it==0 and i>0:
                        it_string = '-'+str(i//1000)+'k'
                        self.model.save(os.path.join('Backup',
                                                     save_name+it_string+'.mdl'))
                    # Create preview prediction images
                    if self.preview_data is not None:
                        if (t_passed-last_save_t2)/3600 > config.prev_save_h \
                            or (t_passed/3600 > config.initial_prev_h and last_save_t2==0): # first time
                            last_save_t2 = t_passed
                            exp_config.preview_kwargs['number'] = save_time2
                            save_time2 += config.prev_save_h
                            try:
                                self.preview_slice(**exp_config.preview_kwargs)
                            except:  # best-effort: previews must never kill training
                                logger.warning("Preview Predictions failed."
                                               "Are the preview raw data in "
                                               "the correct format?")
                            # reset time because we only count training time
                            # not time spent for previews (making previews
                            # is not a computational payload of the actual
                            # training but just for "fun")
                            t0 = time.time()
                    # Adjust the learning rate and other schedule parameters
                    for schedule in self.schedules.values():
                        if i==schedule.next_update:
                            schedule.update(i)
                    if (i%exp_config.history_freq==0) and exp_config.history_freq!=0:
                        lr = self.model.lr
                        mom = self.model.mom
                        if len(self.model.gradnet_rates):
                            gradnetrate = np.mean(self.model.gradnet_rates)
                        else:
                            gradnetrate = 0
                        ### Training & Valid Errors ###
                        # loss_gain: change of loss on the *same* batch after
                        # the update step (negative means the step helped).
                        loss_after = self.model.loss(*batch)
                        loss_gain = loss_after - loss
                        train_loss, train_error = self.test_model('train')
                        valid_loss, valid_error = self.test_model('valid')
                        if not is_regression:
                            train_error *= 100  # to percent
                            valid_error *= 100
                        self.tracker.update_history([i, t_passed, train_loss,
                                                     valid_loss, loss_gain,
                                                     train_error, valid_error,
                                                     lr, mom, gradnetrate])
                        ### Plotting / Saving ###
                        self.model.save(save_name+'-LAST.mdl')
                        self.tracker.save(os.path.join('Backup', save_name))
                        if config.plot_on and ((i>=exp_config.history_freq*3) or i>60):
                            self.tracker.plot(save_name)
                        if config.print_status:
                            t = utils.pretty_string_time(t_passed)
                            out = "%05i L_m=%.3f, L=%.2f, tr=%05.2f%s, "%(i, loss_smooth, loss, train_error, pp_err)
                            out +="vl=%05.2f%s, prev=%04.1f, L_diff=%+.1e, "\
                                  %(valid_error, pp_err, batch_char*100, loss_gain)
                            out +="LR=%.5f, %.2f it/s, %s" %(lr, 1.0/t_pi, t)
                            logger.info(out)
                # User Interface ##############################################
                except (KeyboardInterrupt, ValueError, TypeError) as e:
                    if not isinstance(e, KeyboardInterrupt):
                        traceback.print_exc()
                        print("\nEntering Command line such that Exception can be "
                              "further inspected by user.\n\n")
                    out = "%05i L_m=%.5f, L=%.4f, train=%.5f, valid=%.5f, " %(i, loss_smooth, loss, train_loss, valid_loss)
                    out +="train=%.3f%s, valid=%.3f%s,\n LR=%.6f, MOM=%.6f, "\
                          %(train_error, pp_err, valid_error, pp_err, self.model.lr, self.model.mom)
                    out +="%.1f GPU-it/s, %.1f CPU-it/s, " %(1.0/self.model.time_per_step, 1.0/t_pi)
                    t = utils.pretty_string_time(t_passed)
                    logger.info(out+t)
                    # Like a command line, but cannot change singletons
                    var_push = globals()
                    var_push.update(locals())
                    ret = trainutils.user_input(var_push)
                    if ret=='kill':
                        user_termination = True
                    if config.background_processes:
                        bg_worker.reset()
                    plt.close('all')
                    # reset time after user interaction, otherwise time
                    # will appear as pause in plot
                    t0 = time.time()
                # End UI ######################################################
                # This is in the epoch/UI loop
                if (t_passed > exp_config.max_runtime) or user_termination:
                    logger.info('Timeout or manual Termination')
                    break
            # This is OUTSIDE the training loop i.e.
            # the last block of the function ``run``
            self.model.save(save_name+"-FINAL.mdl")
            if len(self.tracker.timeline) > 10:
                self.tracker.plot(save_name)
            logger.info('End of Training')
            logger.info('#'*60 + '\n' + '#'*60 + '\n')
            # -------------------end of run()----------------------------------
        except:
            sys.excepthook(*sys.exc_info()) # show info on error
        finally:
            if config.background_processes:
                bg_worker.shutdown()
        if self.model.batch_normalisation_active:
            print("Rebuilding model, replacing batch normalisation layers "
                  "with constant values")
            self.model = rebuild_model(self.model, replace_bn='const')
[docs] def test_model(self, data_source): """ Computes Loss and error/accuracy on batch with ``monitor_batch_size`` Parameters ---------- data_source: string 'train' or 'valid' Returns ------- Loss, error: """ # copy because it is modified in next line! kwargs = dict(self.get_batch_kwargs) kwargs['source']=data_source kwargs['batch_size'] = self.exp_config.monitor_batch_size try: batch = self.data.getbatch(**kwargs) except ValueError: logger.warning("Test model, getbatch failed. No validation data?") return np.nan, np.nan # 0, 0 y_aux = [] if batch[1] is None: return 0, 0 if self.exp_config.class_weights is not None: y_aux.append(self.exp_config.class_weights) rates = self.model.dropout_rates self.model.dropout_rates = ([0.0,]*len(rates)) batch_axis = self.model.input_node.shape.tag2index('b') n = batch[0].shape[batch_axis] loss = 0 error = 0 for j in range(int(np.ceil(np.float(n)/self.batch_size))): slice_obj = [slice(None) for i in range(batch_axis+1)] slice_obj[batch_axis] = slice(j*self.batch_size, (j+1)*self.batch_size) d = batch[0][slice_obj] # data l = batch[1][slice_obj] # target if len(batch) > 2: aux = [] for b in batch[2:]: aux.append(b[j*self.batch_size:(j+1)*self.batch_size]) nl, er, pred = self.model.predict_ext(d, l, *(aux + y_aux)) else: nl, er, pred = self.model.predict_ext(d, l, *y_aux) nb_samples = d.shape[batch_axis] loss += nl*nb_samples error += er*nb_samples loss /= n error /= n self.model.dropout_rates = rates # restore old rates return loss, error
    def debug_getcnnbatch(self):
        """
        Executes ``getbatch`` but with un-strided labels and always returning
        info. The first batch example is plotted and the whole batch is
        returned for inspection.

        Returns
        -------
        The full batch tuple (or None after logging a warning if the model
        is not at least 2-dimensional).
        """
        if self.model.ndim>=2:
            kwargs = dict(self.get_batch_kwargs)  # copy: don't mutate shared kwargs
            kwargs['force_dense'] = True
            batch = self.data.getbatch(**kwargs)
            data, target = batch[0][0], batch[1][0]
            # -666 appears to be the ignore/mask label value — zeroed for display.
            target[np.isclose(target, -666)] = 0
            if self.model.ndim==2:
                if target.shape[0] >= 3:
                    # (f,y,x) -> (y,x,3) RGB image; channel order swapped to BGR-like
                    target = np.transpose(target, (1,2,0))[:,:,:3]
                    target = target[...,[2,1,0]]
                else:
                    target = target[0]
                plot_trainingtarget(data[0], target, 1)
            else:
                t_i = target.shape[1]//2  # central z-slice of the target
                if target.shape[0] >= 3:
                    target = np.transpose(target, (1,2,3,0))[0,:,:,:3]
                    target = target[...,[2, 1, 0]]
                else:
                    target = target[0,t_i]
                i = self.data.offsets[0] # z offset
                plot_trainingtarget(data[0,i+t_i], target, 1)
            plt.ion()
            plt.show()
            plt.savefig('Batch_test_image.png', bbox_inches='tight')
            plt.pause(0.01)
            plt.pause(2.0)
            plt.close('all')
            plt.pause(0.01)
            return batch
        else:
            logger.warning("This function is only available for 'img-img' training mode")
    def predict_and_write(self, pred_node, raw_img, number=0, export_class='all',
                          block_name='', z_thick=5):
        """
        Predict and and save a slice as preview image

        Parameters
        ----------
        pred_node :
            Node whose ``predict_dense`` produces the dense prediction.
        raw_img : np.ndarray
            raw data in the format (ch, x, y, z)
        number: int/float
            consecutive number for the save name (i.e. hours, iterations etc.)
        export_class: str or int
            'all' writes images of all classes, otherwise only the class with
            index ``export_class`` (int) is saved. 'malis'/'affinity' write a
            3-channel affinity image; a list/tuple writes each listed class.
        block_name: str
            Name/number to distinguish different raw_imges
        z_thick: int
            Number of central z-slices of the prediction to export.
        """
        block_name = str(block_name)
        pred = pred_node.predict_dense(raw_img) # returns (k, (z,) y, x)
        z_sh = pred.shape[1]
        # keep only the central z_thick slices
        pred = pred[:,(z_sh-z_thick)//2:(z_sh-z_thick)//2+z_thick,:,:,]
        save_name = self.exp_config.save_name
        for z in range(pred.shape[1]):
            if export_class=='all':
                for c in range(pred.shape[0]):
                    plt.imsave('%s-pred-%s-z%i-c%i-%shrs.png' \
                               %(save_name, block_name, z, c, number),
                               pred[c,z,:,:], cmap='gray')
            elif export_class in ['malis', 'affinity']:
                # channels 0,2,4 stacked as an RGB affinity preview
                plt.imsave('%s-pred-%s-aff-z%i-%shrs.png' \
                           %(save_name, block_name, z, number),
                           np.transpose(pred[0:6:2,z,:,:],(1,2,0)), cmap='gray')
            else:
                if isinstance(export_class, (list, tuple)):
                    for c in export_class:
                        plt.imsave('%s-pred-%s-z%i-c%i-%shrs.png' \
                                   %(save_name, block_name, z, c, number),
                                   pred[c,z,:,:], cmap='gray')
                else:
                    c = int(export_class)
                    plt.imsave('%s-pred-%s-z%i-c%i-%shrs.png' \
                               %(save_name, block_name, z, c, number),
                               pred[c,z,:,:], cmap='gray')
        if not self.saved_raw_preview: # only do once
            # offset aligns raw slices with the (smaller) prediction volume
            if len(pred_node.shape.offsets)==2:
                z_off = 0
            else:
                z_off = int(pred_node.shape.offsets[0])
            for z in range(pred.shape[1]):
                plt.imsave('%s-raw-%s-z%i.png'%(save_name, block_name, z),
                           raw_img[0,z+z_off,:,:], cmap='gray')
    def preview_slice_from_traindata(self, cube_i=0, off=(0,0,0),
                                     sh=(10,400,400), number=0,
                                     export_class='all'):
        """
        Predict and and save a selected slice from the training data as preview

        Parameters
        ----------
        cube_i: int
            index of source cube in CNNData
        off: 3-tuple of int
            start index of slice to cut from cube (z,y,x)
        sh: 3-tuple of int
            shape of cube to cut (z,y,x)
        number: int
            consecutive number for the save name (i.e. hours, iterations etc.)
        export_class: str or int
            'all' writes images of all classes, otherwise only the class with
            index ``export_class`` (int) is saved.

        Raises
        ------
        RuntimeError
            If no spatial prediction node exists or the model is not 2/3-dim.
        """
        # Prefer the model's prediction node if it is spatial; otherwise fall
        # back to an explicit 'pred_dense' node.
        if self.model.prediction_node.shape.ndim >= 2:
            pred_node = self.model.prediction_node
        elif "pred_dense" in self.model.nodes:
            pred_node = self.model['pred_dense']
        else:
            # NOTE(review): message reads garbled — presumably "Model must
            # have a spatial prediction node..."; kept byte-identical.
            raise RuntimeError("Model have spatial prediction node or"
                               " 'pred_dense' node which is spatial")
        if self.model.ndim==3:
            # Grow the requested z-extent to the model's minimal input depth.
            min_z = self.model.prediction_node.input_nodes[0].shape['z']
            if min_z > sh[0]:
                sh = list(sh)
                sh[0] = min_z
        elif self.model.ndim==2:
            pass
        else:
            raise RuntimeError("Model must be 2/3 dimensional for previews")
        raw_img = self.data.train_d[cube_i]
        raw_img = raw_img[:, off[0]:off[0]+sh[0], off[1]:off[1]+sh[1],
                          off[2]:off[2]+sh[2]]
        self.predict_and_write(pred_node, raw_img, number, export_class)
        self.saved_raw_preview = True
    def preview_slice(self, number=0, export_class='all', max_z_pred=5):
        """
        Predict and and save a data from a separately loaded file as preview

        Parameters
        ----------
        number: int/float
            consecutive number for the save name (i.e. hours, iterations etc.)
        export_class: str or int
            'all' writes images of all classes, otherwise only the class with
            index ``export_class`` (int) is saved.
        max_z_pred: int
            approximate maximal number of z-slices to produce (depends on CNN
            architecture)

        Raises
        ------
        ValueError
            If a preview cube is too thin in z for the model's input size.
        RuntimeError
            If no spatial prediction node exists or the model is not 2/3-dim.
        """
        assert self.preview_data is not None, "You must provide preview data in order to call this function"
        for example_no, raw_img in enumerate(self.preview_data):
            if raw_img.ndim==3:
                # Heuristic axis fix: assume the smallest axis is z.
                if raw_img.shape[0]>raw_img.shape[2]:
                    raw_img = np.transpose(raw_img, (2,0,1))
                    logger.warning("preview_slice: transposing preview image, assuming last dim is z because "
                                   " this dim is smaller than the first.")
            z_sh = raw_img.shape[0] if raw_img.ndim==3 else raw_img.shape[1]
            if self.model.prediction_node.shape.ndim>=2:
                pred_node = self.model.prediction_node
            elif "pred_dense" in self.model.nodes:
                pred_node = self.model['pred_dense']
            else:
                # NOTE(review): message reads garbled — presumably "Model must
                # have a spatial prediction node..."; kept byte-identical.
                raise RuntimeError("Model have spatial prediction node or"
                                   " 'pred_dense' node which is spatial")
            if pred_node.shape.ndim==3:
                # Compute the input z-thickness needed to produce roughly
                # max_z_pred output slices, given the node's z-stride.
                strd_z = pred_node.shape.strides[0]
                out_z = pred_node.shape.spatial_shape[0] * strd_z
                min_z = pred_node.input_nodes[0].shape.spatial_shape[0] + strd_z - 1 # input shape
                z_thick = min_z if out_z > max_z_pred else min_z + strd_z*int(np.ceil(float(max_z_pred-out_z)/strd_z))
            elif pred_node.shape.ndim==2:
                z_thick = max_z_pred
            else:
                raise RuntimeError("Model must be 2/3 dimensional for previews")
            if z_thick > z_sh:
                raise ValueError("The preview slices are too small in z-direction for this CNN")
            # Cut the central z_thick slab; add a channel axis for 3D arrays.
            if raw_img.ndim==3:
                raw_img = raw_img[None, (z_sh-z_thick)//2:(z_sh-z_thick)//2+z_thick, :, :]
            elif raw_img.ndim==4:
                raw_img = raw_img[:, (z_sh-z_thick)//2:(z_sh-z_thick)//2+z_thick, :, :]
            self.predict_and_write(pred_node, raw_img, number, export_class,
                                   example_no, max_z_pred)
        self.saved_raw_preview = True
###############################################################################
class TracingTrainer(Trainer):
    """Trainer variant for skeleton-tracing training; overrides batching,
    debugging and the run loop to handle (img, vec, skeleton, trafo) batches."""

    @staticmethod
    def save_batch(img, lab, k, lab_img=None):
        """
        Dump a training batch to disk for inspection: HDF5 of the first
        example's image (and optional label image) plus per-z-slice PNGs.

        Parameters
        ----------
        img : np.ndarray
            Batch images; only example 0 is saved. Indexed as img[0][ch, z, ...].
        lab : np.ndarray
            Batch labels; only used to compute the (currently unused) offset.
        k : int
            Consecutive number used in the output file names.
        lab_img : np.ndarray or None
            Optional dense label image, saved to its own HDF5 file.
        """
        img = img[0]
        lab = lab[0]
        off = img.shape[1] - lab.shape[1]  # z-size difference img vs. label
        utils.h5save(img, 'img-%i.h5' % k)
        if lab_img is not None:
            utils.h5save(lab_img, 'lab_img-%i.h5' % k)
        # assert off % 2==0
        # off //= 2
        for i in range(img.shape[1]):
            plt.imsave('batch-%i-z%i.png' % (k, i), img[0, i], cmap='gray')
            # if 0 <= (i - off) < lab.shape[1]:
            #     lab_small = lab[4, i - off]
            #     lab_up = misc.imresize(lab_small,
            #                            np.multiply(lab_small.shape, 8),
            #                            interp='nearest')
            #     plt.imsave('batch-%i-z%i-l.png' % (k, i), lab_up, cmap='gray')

    # def probmap_preview(self, raw_img, number=0, block_name=''):
    #     """
    #     Predict and and save a slice as preview image
    #
    #     Parameters
    #     ----------
    #
    #     raw_img : np.ndarray
    #         raw data in the format (ch, x, y, z)
    #     number: int/float
    #         consecutive number for the save name (i.e. hours, iterations etc.)
    #     block_name: str
    #         Name/number to distinguish different raw_imges
    #     """
    #     block_name = str(block_name)
    #     pred = self.model['pred_dense'].predict_dense(raw_img) # returns (k, z, y, x)
    #
    #     save_name = self.exp_config.save_name
    #     #names = ['vz', 'vx', 'vy', 'br', 'barr', 'bg', 'syn', 'ves', 'mito']
    #     for z in range(pred.shape[1]):
    #         for c in range(pred.shape[0]):
    #             plt.imsave('%s-pred-%s-z%i-c%i-%shrs.png' \
    #                        % (save_name, block_name, z, c, number),
    #                        pred[c, z, :, :], cmap='gray')
    #
    #     z_off = int(self.model['vec'].shape.offsets[0])
    #     for z in range(pred.shape[1]):
    #         plt.imsave('%s-raw-%s-z%i.png' % (save_name, block_name, z),
    #                    raw_img[0, z + z_off, :, :], cmap='gray')
    def debug_getcnnbatch(self, extended=False):
        """
        Executes ``getbatch`` but with un-strided labels and always returning
        info. The first batch example is plotted and the whole batch is
        returned for inspection.

        Parameters
        ----------
        extended : bool
            If True (3D models only), additionally dump per-slice PNGs of all
            target channels and quiver plots of the direction vectors to /tmp.

        Returns
        -------
        The full batch tuple.
        """
        kwargs = dict(self.get_batch_kwargs)  # copy: don't mutate shared kwargs
        kwargs['force_dense'] = True
        batch = self.data.getbatch(**kwargs)
        data, target = batch[0][0], batch[1][0]
        # -666 appears to be the ignore/mask label value — zeroed for display.
        target[np.isclose(target, -666)] = 0
        if self.model.ndim==2:
            if target.shape[0] >= 3:
                target = np.transpose(target, (1, 2, 0))[:, :, :3]
                # NOTE(review): channel order [2,0,1] differs from the base
                # class's [2,1,0] — confirm which is intended.
                target = target[...,[2, 0, 1]]
            else:
                target = target[0]
            plot_trainingtarget(data[0], target, 1)
        else:
            t_i = target.shape[1] // 2  # central z-slice
            if target.shape[0] >= 3:
                target = np.transpose(target, (1, 2, 3, 0))[0, :, :, :3]
                # NOTE(review): [2,0,2] duplicates channel 2 — looks like a
                # typo for [2,0,1]; kept byte-identical.
                target = target[...,[2, 0, 2]]
            else:
                target = target[0, t_i]
            i = self.data.offsets[0] # z offset
            plot_trainingtarget(data[0, i+t_i], target, 1)
        if self.model.ndim==3 and extended:
            dest = '/tmp/%s-'%user_name
            data, target = batch[0], batch[1]
            target[np.isclose(target, -666)] = 0
            i = self.data.offsets[0] # z offset
            for j in range(data.shape[2]):
                plt.imsave('/tmp/img-%i.png' % j, data[0, 0, j], cmap='gray')
                if j - i >= 0 and j - i < target.shape[2]:
                    # Channel indices below presumably map to
                    # [z,y,x vectors, barrier, branch, ..., syn, ves, mito]
                    # — verify against the data class.
                    plt.imsave(dest+'img-%i-br.png'%j, target[0, 4, j - i],cmap='gray')
                    plt.imsave(dest+'img-%i-z.png'%j, target[0,0,j-i], cmap='gray')
                    plt.imsave(dest+'img-%i-y.png'%j, target[0,1,j-i], cmap='gray')
                    plt.imsave(dest+'img-%i-x.png'%j, target[0,2,j-i], cmap='gray')
                    plt.imsave(dest+'img-%i-barr.png'%j, target[0,3,j-i], cmap='gray')
                    plt.imsave(dest+'img-%i-syn.png'%j, target[0,6,j-i], cmap='gray')
                    plt.imsave(dest+'img-%i-ves.png'%j, target[0,7,j-i], cmap='gray')
                    plt.imsave(dest+'img-%i-mito.png'%j, target[0,8,j-i], cmap='gray')
                    quiver = my_quiver(target[0,2,j-i], target[0,1,j-i],
                                       img=target[0, 4, j - i],
                                       c=target[0,0,j-i])
                    quiver.savefig(dest+'vec-%i.png'%j, bbox_inches='tight')
        plt.ion()
        plt.show()
        plt.savefig('Batch_test_image.png', bbox_inches='tight')
        plt.pause(0.01)
        plt.pause(2.0)
        plt.close('all')
        plt.pause(0.01)
        return batch
    def run(self):
        """
        Training loop for skeleton tracing.

        In sequence-training mode the previous step's prediction is converted
        to a new position on the current skeleton and the next slice is cut
        from there (``data.get_newslice``); when the track is lost (or after
        ``max_tracing`` steps) a fresh skeleton batch is drawn. Otherwise it
        behaves like ``Trainer.run`` (tracking, saving, previews, schedules,
        interactive UI on interrupt).
        """
        exp_config = self.exp_config
        save_name = exp_config.save_name
        data = self.data
        self.tracker.register_debug_output_names(self.model.debug_output_names[1:]) #remove first because it is prediction
        t_passed = 0
        t_pt = 2
        t_pi = 2  # EMA of seconds per iteration, seeded at 2
        last_save_t = 0
        last_save_t2 = 0
        save_time = config.param_save_h
        save_time2 = config.initial_prev_h
        loss, loss_smooth, train_loss, valid_loss, train_error, valid_error, param_vars = 0, 0, 0, 0, 0, 0, 0
        user_termination = False
        if isinstance(self.model.loss_node.parent[0], SquaredLoss):
            is_regression = True
        else:
            is_regression = False
        pp_err = 'err' if is_regression else '%'
        # --------------------------------------------------------------------
        if config.background_processes:
            n_proc = max(2, int(config.background_processes))
            bg_worker = BackgroundProc(data.getbatch, n_proc=n_proc,
                                       target_kwargs=self.get_batch_kwargs)
        # --------------------------------------------------------------------
        try:
            lost_track = True
            tracing_length = 1
            i = -1
            t0 = time.time()
            while i < exp_config.n_steps:
                # update max every loop to make it modifyable during training
                max_tracing = exp_config.sequence_training if exp_config.sequence_training>1 else 50
                try:
                    if config.inspection:
                        inspection_logger.info("#BATCH %i" % (i + 1))
                    # check if we lost the skeleton track
                    if exp_config.sequence_training:
                        try:
                            # UnboundLocalError on the very first pass, before
                            # any batch exists
                            lost_track = batch[2].lost_track
                        except UnboundLocalError:
                            pass
                        if tracing_length >= max_tracing:
                            lost_track = True
                    # if we are still on track get a new slice from this skeleton
                    if not lost_track and exp_config.sequence_training:
                        tracing_length += 1
                        position_l, direction_il = batch[3].cnn_pred2lab_position(
                            prediction_c)
                        try:
                            tmp = data.get_newslice(position_l, direction_il,
                                                    **self.get_batch_kwargs)
                            img, vec, trafo = tmp[:3]
                            if len(tmp)==4:
                                batch = (img, vec, batch[2], trafo, tmp[3])
                            else:
                                batch = (img, vec, batch[2], trafo)
                        # if get_newslice fails, do same as when track is lost
                        except transformations.WarpingOOBError:
                            lost_track = True
                    if lost_track and exp_config.sequence_training:
                        print("Traced for %i iterations" %(tracing_length,))
                        tracing_length = 0
                        # Flush the finished trace's debug buffers onto the skeleton.
                        try:
                            skel = batch[2]
                            skel.debug_traces.append(np.array(skel.debug_traces_current))
                            skel.debug_traces_current = []
                            skel.debug_grads.append(np.array(skel.debug_grads_current))
                            skel.debug_grads_current = []
                        except UnboundLocalError:
                            pass
                    # non-sequence training or getting a new skeleton
                    if lost_track or (not exp_config.sequence_training):
                        if config.background_processes:
                            batch = bg_worker.get()
                        else:
                            batch = data.getbatch(**self.get_batch_kwargs)
                        if exp_config.sequence_training:
                            #print("Next skel: %i" % (batch[2],))
                            pass
                        # batch[2]/batch[3] arrive as indices/arrays and are
                        # resolved to skeleton object and transformation here.
                        batch = list(batch)
                        batch[2] = data.train_s[batch[2]]
                        batch[3] = transformations.trafo_from_array(batch[3][0])
                    if config.inspection:
                        lab_img = batch[4]
                        batch = batch[:4]
                        if (i+1) % 50==0:
                            self.save_batch(batch[0], batch[1], (i + 1), lab_img)
                    # ---------------------------------------------------------
                    loss, t_per_train, debug_outputs = self.model.trainingstep(*batch, optimiser=exp_config.optimiser) # Update step
                    # first debug output is the prediction, used for the next
                    # tracing step; the rest go to the tracker
                    prediction_c = debug_outputs[0][0]
                    debug_outputs = debug_outputs[1:]
                    i += 1
                    # ---------------------------------------------------------
                    t_per_it = time.time() - t0
                    t_passed += t_per_it
                    t0 = time.time()
                    t_pi = 0.8 * t_pi + 0.2 * t_per_it # EMA
                    loss_smooth = self.model.loss_smooth
                    # check for divergence
                    if np.any(np.isnan(loss)) or np.any(np.isinf(loss)):
                        logger.warning(NLL_TEXT)
                        raise KeyboardInterrupt
                        # self.model.optimisers[exp_config.optimiser].repair_fuckup()
                    self.tracker.update_timeline([t_passed, loss,
                                                  debug_outputs[0]/10])
                    if debug_outputs:
                        self.tracker.update_debug_outputs([i, loss, ] + debug_outputs)
                    # Save Parameters
                    # if (t_passed-last_save_t)/3600 > config.param_save_h:
                    #     last_save_t = t_passed
                    #     time_string = '-'+str(save_time)+'h'
                    #     self.model.save(os.path.join('Backup', save_name+time_string+'.mdl'))
                    #     save_time += config.param_save_h
                    if i%config.param_save_it==0 and i>0:
                        it_string = '-'+str(i//1000)+'k'
                        self.model.save(os.path.join('Backup',
                                                     save_name+it_string+'.mdl'))
                    # Create preview prediction images
                    if self.preview_data is not None:
                        if (t_passed - last_save_t2) / 3600 > config.prev_save_h \
                            or (t_passed / 3600 > config.initial_prev_h and last_save_t2==0): # first time
                            last_save_t2 = t_passed
                            exp_config.preview_kwargs['number'] = save_time2
                            save_time2 += config.prev_save_h
                            try:
                                self.preview_slice(**exp_config.preview_kwargs)
                            except:  # best-effort: previews must never kill training
                                logger.warning("Preview Predictions failed."
                                               "Are the preview raw data in "
                                               "the correct format?")
                            # reset time because we only count training time
                            # not time spent for previews (making previews
                            # is not a computational payload of the actual
                            # training but just for "fun")
                            t0 = time.time()
                    # Adjust the learning rate and other schedule parameters
                    for schedule in self.schedules.values():
                        if i==schedule.next_update:
                            schedule.update(i)
                    if (i % exp_config.history_freq==0) \
                            and exp_config.history_freq!=0:
                        lr = self.model.lr
                        mom = self.model.mom
                        if len(self.model.gradnet_rates):
                            gradnetrate = np.mean(self.model.gradnet_rates)
                        else:
                            gradnetrate = 0
                        ### Training & Valid Errors ###
                        loss_after = self.model.loss(*batch)
                        loss_gain = loss_after - loss
                        train_loss, train_error = self.test_model('train')
                        valid_loss, valid_error = self.test_model('valid')
                        self.tracker.update_history([i, t_passed, train_loss,
                                                     valid_loss, loss_gain,
                                                     train_error, valid_error,
                                                     lr, mom, gradnetrate])
                        ### Plotting / Saving ###
                        self.model.save(save_name + '-LAST.mdl')
                        self.tracker.save(os.path.join('Backup', save_name))
                        if config.plot_on and ((i>=exp_config.history_freq*3) or i>60):
                            self.tracker.plot(save_name)
                        if config.print_status:
                            t = utils.pretty_string_time(t_passed)
                            out = "%05i L_m=%.3f, L=%.2f, tr=%05.2f%s, " % (i, loss_smooth, loss, train_error, pp_err)
                            out += "vl=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " \
                                   % (valid_error, pp_err,
                                      debug_outputs[0] * 100, loss_gain)
                            out += "LR=%.5f, %.2f it/s, %s" % (
                                lr, 1.0 / t_pi, t)
                            logger.info(out)
                # User Interface ##############################################
                except (KeyboardInterrupt, ValueError, TypeError) as e:
                    if not isinstance(e, KeyboardInterrupt):
                        traceback.print_exc()
                        print(
                            "\nEntering Command line such that Exception can be "
                            "further inspected by user.\n\n")
                    out = "%05i L_m=%.5f, L=%.4f, train=%.5f, valid=%.5f, " % (
                        i, loss_smooth, loss, train_loss, valid_loss)
                    out += "train=%.3f%s, valid=%.3f%s,\n LR=%.6f, MOM=%.6f, " \
                           % (train_error, pp_err, valid_error, pp_err,
                              self.model.lr, self.model.mom)
                    out += "%.1f GPU-it/s, %.1f CPU-it/s, " % (
                        1.0 / self.model.time_per_step, 1.0 / t_pi)
                    t = utils.pretty_string_time(t_passed)
                    logger.info(out + t)
                    # Like a command line, but cannot change singletons
                    var_push = globals()
                    var_push.update(locals())
                    ret = trainutils.user_input(var_push)
                    if ret=='kill':
                        user_termination = True
                    if config.background_processes:
                        bg_worker.reset()
                    plt.close('all')
                    # reset time after user interaction, otherwise time
                    # will appear as pause in plot
                    t0 = time.time()
                # End UI ######################################################
                # This is in the epoch/UI loop
                if (t_passed > exp_config.max_runtime) or user_termination:
                    logger.info('Timeout or manual Termination')
                    break
            # This is OUTSIDE the training loop i.e.
            # the last block of the function ``run``
            self.model.save(save_name + "-FINAL.mdl")
            if len(self.tracker.timeline) > 10:
                self.tracker.plot(save_name)
            logger.info('End of Training')
            logger.info('#' * 60 + '\n' + '#' * 60 + '\n')
            # -------------------end of run()----------------------------------
        except:
            sys.excepthook(*sys.exc_info()) # show info on error
        finally:
            if config.background_processes:
                bg_worker.shutdown()
        if self.model.batch_normalisation_active:
            self.model = rebuild_model(self.model, replace_bn='const')
    def test_model(self, data_source):
        """
        Computes Loss and error/accuracy on batch with ``monitor_batch_size``

        Unlike the base class this resolves skeleton indices and
        transformation arrays in the batch and evaluates examples one by one
        (requires ``batch_size == 1``).

        Parameters
        ----------
        data_source: string
            'train' or 'valid'

        Returns
        -------
        Loss, error:
        """
        assert self.batch_size==1
        # copy because it is modified in next line!
        kwargs = dict(self.get_batch_kwargs)
        kwargs['source'] = data_source
        kwargs['batch_size'] = self.exp_config.monitor_batch_size
        try:
            batch = self.data.getbatch(**kwargs)
        except ValueError:
            logger.warning("Test model, getbatch failed. No validation data?")
            return np.nan, np.nan # 0, 0
        # Resolve skeleton indices to objects and trafo arrays to objects.
        batch = list(batch)
        batch[2] = list(batch[2])
        batch[3] = list(batch[3])
        for i in range(self.exp_config.monitor_batch_size):
            if data_source=='train':
                batch[2][i] = self.data.train_s[batch[2][i]]
            elif data_source=='valid':
                batch[2][i] = self.data.valid_s[batch[2][i]]
            batch[3][i] = transformations.trafo_from_array(batch[3][i])
        if config.inspection:
            batch = batch[:4]  # drop the inspection label image
        rates = self.model.dropout_rates
        self.model.dropout_rates = ([0.0, ] * len(rates))  # disable dropout for eval
        n = len(batch[0])
        loss = 0
        error = 0
        for j in range(n):
            d = batch[0][j:j+1] # data
            l = batch[1][j:j+1] # target
            if len(batch) > 2:
                aux = []
                for b in batch[2:]:
                    aux.append(b[j])
                nl, er, pred = self.model.predict_ext(d, l, *aux)
                skel = batch[2][j]
                # predict_ext calls get_loss_and_gradient
                # which adds current position and grad, but for testing model
                # we do not want these tracked --> remove again
                skel.debug_traces_current.pop()
                skel.debug_grads_current.pop()
            else:
                nl, er, pred = self.model.predict_ext(d, l)
            loss += nl * len(d)
            error += er * len(d)
        loss /= n
        error /= n
        self.model.dropout_rates = rates # restore old rates
        return loss, error
###############################################################################
[docs]class TracingTrainerRNN(TracingTrainer):
    def run(self):
        """
        Run the tracing-training main loop until ``exp_config.n_steps`` is
        reached, the runtime budget is exhausted, or the user terminates.

        Per iteration this method:
          * draws a new skeleton example when the previous trace ended
            (lost track, out-of-bounds warp, or max tracing length reached),
          * performs one optimiser step via ``self.model.trainingstep``,
          * records loss/debug outputs in ``self.tracker``,
          * periodically saves parameters, renders preview predictions,
            applies parameter schedules, and evaluates train/valid loss.

        ``KeyboardInterrupt``/``ValueError``/``TypeError`` drop into an
        interactive console (``trainutils.user_input``) instead of aborting;
        `nan`/`inf` loss is converted into a ``KeyboardInterrupt`` to reach
        that console. Saves ``<save_name>-FINAL.mdl`` on exit.
        """
        exp_config = self.exp_config
        save_name = exp_config.save_name
        data = self.data
        # Optional radius-regression targets for scanning are carried as the
        # last two debug outputs; strip them from the tracked names here and
        # route them to tracker.update_regression() below.
        if 'scan_out_radius_t'==self.model.debug_output_names[-1]:
            assert 'scan_out_radius'==self.model.debug_output_names[-2]
            print("Found regression targets for scan")
            self.tracker.register_debug_output_names(
                self.model.debug_output_names[:-2])

        # Hidden-memory buffer fed as second input when the loss node takes
        # two inputs (recurrent tracing). NOTE(review): passed unchanged every
        # step here — presumably updated model-internally; confirm.
        mem_hid = np.zeros(self.model['mem_hid'].shape, dtype=floatX)
        t_passed = 0
        t_pt = 2  # NOTE(review): initialised but not used in this method
        t_pi = 2  # EMA of wall-clock time per iteration (seeded at 2 s)
        last_save_t = 0
        last_save_t2 = 0
        save_time = config.param_save_h
        save_time2 = config.initial_prev_h
        loss, loss_smooth, train_loss, valid_loss, train_error, valid_error, \
            param_vars = 0, 0, 0, 0, 0, 0, 0
        user_termination = False
        is_regression = True
        pp_err = 'err'  # unit label appended to error values in status lines
        # ----------------------------------------------------------------------
        try:
            lost_track = True
            tracing_length = 0  # steps taken on the current skeleton
            i = -1              # global iteration counter (first step -> 0)
            t0 = time.time()
            skel_example = None
            while i < exp_config.n_steps:
                # update max every loop to make it modifyable during training
                max_tracing = exp_config.sequence_training
                try:
                    # Start a fresh trace if the previous one has ended.
                    if skel_example is None:
                        skel_example, skel_index = data.getskel('train')
                        if len(self.model.loss_node.input_nodes)==2:
                            batch = (skel_example, mem_hid)
                        else:
                            batch = (skel_example,)
                        skel_example.start_new_training = True
                        tracing_length = 0
                        if config.inspection:
                            inspection_logger.info("NEW SKEL %i" %skel_index)
                    # ------------------------------------------------------
                    try:
                        loss, t_per_train, debug_outputs = \
                            self.model.trainingstep(
                                *batch, optimiser=exp_config.optimiser)
                        # Update step
                        tracing_length += 1
                        i += 1
                        if config.inspection:
                            inspection_logger.info("-"*20)
                    except transformations.WarpingOOBError:
                        # Trace walked out of the data volume: abandon this
                        # skeleton and draw a new one next iteration.
                        if config.inspection:
                            inspection_logger.info(
                                "OOB, Traced for %i iterations"
                                % (tracing_length,))
                        skel_example = None
                        continue
                    # ------------------------------------------------------
                    t_per_it = time.time() - t0
                    t_passed += t_per_it
                    t0 = time.time()
                    t_pi = 0.8 * t_pi + 0.2 * t_per_it  # EMA
                    loss_smooth = self.model.loss_smooth
                    # End the trace when the model lost the skeleton or the
                    # maximum sequence length is reached.
                    if skel_example.lost_track or tracing_length >= max_tracing:
                        if config.inspection:
                            inspection_logger.info(
                                "Traced for %i iterations" % (tracing_length,))
                        skel_example = None

                    # check for divergence
                    if np.any(np.isnan(loss)) or np.any(np.isinf(loss)):
                        logger.warning(NLL_TEXT)
                        # Raising KeyboardInterrupt drops into the interactive
                        # console below so state can be inspected.
                        raise KeyboardInterrupt
                        # self.model.optimisers[exp_config.optimiser].repair_fuckup()

                    self.tracker.update_timeline([t_passed, loss, 0])
                    if debug_outputs:
                        # Split off the radius-regression pair (see top of
                        # method) before averaging the remaining outputs.
                        if 'scan_out_radius_t'==self.model.debug_output_names[-1]:
                            r = debug_outputs[-2]
                            r_t = debug_outputs[-1]
                            debug_outputs = debug_outputs[:-2]
                            self.tracker.update_regression(r.ravel(), r_t.ravel())
                        debug_outputs_ = [np.mean(x) for x in debug_outputs]
                        self.tracker.update_debug_outputs(
                            [i, loss, ] + debug_outputs_)

                    # Save Parameters
                    # (time-based checkpointing kept for reference, disabled in
                    # favour of iteration-based checkpointing below)
                    # if (t_passed-last_save_t)/3600 > config.param_save_h:
                    #     last_save_t = t_passed
                    #     time_string = '-'+str(save_time)+'h'
                    #     self.model.save(os.path.join('Backup', save_name+time_string+'.mdl'))
                    #     save_time += config.param_save_h
                    if i%config.param_save_it==0 and i>0:
                        it_string = '-'+str(i//1000)+'k'
                        self.model.save(
                            os.path.join('Backup', save_name+it_string+'.mdl'))

                    # Create preview prediction images
                    if self.preview_data is not None:
                        if (t_passed - last_save_t2) / 3600 > config.prev_save_h \
                            or (t_passed / 3600 > config.initial_prev_h
                                and last_save_t2==0):  # first time
                            last_save_t2 = t_passed
                            exp_config.preview_kwargs['number'] = save_time2
                            save_time2 += config.prev_save_h
                            try:
                                self.preview_slice(**exp_config.preview_kwargs)
                            except:
                                # Best-effort: preview failure must not kill
                                # the training run.
                                logger.warning("Preview Predictions failed."
                                               "Are the preview raw data in "
                                               "the correct format?")
                            # reset time because we only count training time
                            # not time spent for previews (making previews
                            # is not a computational payload of the actual
                            # training but just for "fun")
                            t0 = time.time()

                    # Adjust the learning rate and other schedule parameters
                    for schedule in self.schedules.values():
                        if i==schedule.next_update:
                            schedule.update(i)

                    if (i % exp_config.history_freq==0) \
                        and exp_config.history_freq!=0:
                        lr = self.model.lr
                        mom = self.model.mom
                        if len(self.model.gradnet_rates):
                            gradnetrate = np.mean(self.model.gradnet_rates)
                        else:
                            gradnetrate = 0

                        ### Training & Valid Errors ###
                        loss_after = loss  # too expensive to really compute for RNN
                        loss_gain = loss_after - loss  # hence always 0 here
                        train_loss, train_error = self.test_model('train')
                        valid_loss, valid_error = self.test_model('valid')

                        self.tracker.update_history(
                            [i, t_passed, train_loss, valid_loss, loss_gain,
                             train_error, valid_error, lr, mom, gradnetrate])

                        ### Plotting / Saving ###
                        self.model.save(save_name + '-LAST.mdl')
                        self.tracker.save(os.path.join('Backup', save_name))
                        if config.plot_on and \
                            ((i>=exp_config.history_freq*3) or i>60):
                            self.tracker.plot(save_name)

                        if config.print_status:
                            t = utils.pretty_string_time(t_passed)
                            out = "%05i L_m=%.3f, L=%.2f, tr=%05.2f%s, "% (
                                i, loss_smooth, loss, train_error, pp_err)
                            out += "vl=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " \
                                   % (valid_error, pp_err, 0 * 100, loss_gain)
                            out += "LR=%.5f, %.2f it/s, %s" % (
                                lr, 1.0 / t_pi, t)
                            logger.info(out)

                # User Interface ##############################################
                except (KeyboardInterrupt, ValueError, TypeError) as e:
                    if not isinstance(e, KeyboardInterrupt):
                        traceback.print_exc()
                        print(
                            "\nEntering Command line such that Exception can be "
                            "further inspected by user.\n\n")

                    out = "%05i L_m=%.5f, L=%.4f, train=%.5f, valid=%.5f, " % (
                        i, loss_smooth, loss, train_loss, valid_loss)
                    out += "train=%.3f%s, valid=%.3f%s,\n LR=%.6f, MOM=%.6f, " \
                           % (train_error, pp_err, valid_error, pp_err,
                              self.model.lr, self.model.mom)
                    out += "%.1f GPU-it/s, %.1f CPU-it/s, " % (
                        1.0 / self.model.time_per_step, 1.0 / t_pi)
                    t = utils.pretty_string_time(t_passed)
                    logger.info(out + t)

                    # Like a command line, but cannot change singletons
                    var_push = globals()
                    var_push.update(locals())
                    ret = trainutils.user_input(var_push)
                    if ret=='kill':
                        user_termination = True

                    plt.close('all')
                    # reset time after user interaction, otherwise time
                    # will appear as pause in plot
                    t0 = time.time()
                # End UI ##################################################

                # This is in the epoch/UI loop
                if (t_passed > exp_config.max_runtime) or user_termination:
                    logger.info('Timeout or manual Termination')
                    break

            # This is OUTSIDE the training loop i.e.
            # the last block of the function ``run``
            self.model.save(save_name + "-FINAL.mdl")
            if len(self.tracker.timeline) > 10:
                self.tracker.plot(save_name)
            logger.info('End of Training')
            logger.info('#' * 60 + '\n' + '#' * 60 + '\n')
            # -------------------end of run()----------------------------------
        except:
            # Top-level catch-all: report via the installed excepthook so the
            # traceback is shown even when run under a wrapper.
            sys.excepthook(*sys.exc_info())  # show info on error
        finally:
            pass
            # if self.model.batch_normalisation_active:
            #     self.model = rebuild_model(self.model, replace_bn='const')
[docs] def test_model(self, data_source): #return 1.0, 0.5 """ Computes Loss and error/accuracy on batch with ``monitor_batch_size`` Parameters ---------- data_source: string 'train' or 'valid' Returns ------- Loss, error: """ try: skel_example, skel_index = self.data.getskel(data_source) except ValueError: logger.warning("Test model, getbatch failed. No validation data?") return np.nan, np.nan # 0, 0 rates = self.model.dropout_rates self.model.dropout_rates = ([0.0, ] * len(rates)) n = self.exp_config.monitor_batch_size loss = 0 j = 0 while j < n: #print(j, end="") try: skel_example, skel_index = self.data.getskel(data_source) skel_example.start_new_training = True nl = self.model.loss(skel_example) except transformations.WarpingOOBError: continue j += 1 loss += nl loss /= n self.model.dropout_rates = rates # restore old rates return loss, loss