# -*- coding: utf-8 -*-
# ELEKTRONN - Neural Network Toolkit
#
# Copyright (c) 2014 - now
# Max-Planck-Institute for Medical Research, Heidelberg, Germany
# Authors: Marius Killinger
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, \
super, zip
import os
from collections import OrderedDict
from matplotlib import pyplot as plt
from scipy import stats
import numpy as np
import seaborn as sns
import logging
logger = logging.getLogger('elektronn2log')
def _scroll_plot1(image, name, init_z):
"""
Creates a plot of 3d volume images
Scrolling changes the displayed slices
Parameters
----------
    image: array of shape (z,y,x) or (z,y,x,RGB)
Usage
-----
    For the scroll interaction to work, the "scroller" object
must be returned to the calling scope
    >>> scroller = _scroll_plot1(image, name, init_z)
    >>> plt.show()
"""
fig = plt.figure(figsize=(12, 12))
ax1 = fig.add_subplot(111)
scroller = Scroller([ax1], [image, ], [name, ], init_z)
fig.canvas.mpl_connect('scroll_event', scroller.onscroll)
fig.tight_layout()
return scroller
def _scroll_plot2(images, names, init_z):
"""
    Creates a 1x2 image plot of 3d volume images
Scrolling changes the displayed slices
Parameters
----------
images: list of 2 arrays
Each array of shape (z,y,x) or (z,y,x,RGB)
names: list of 2 strings
Names for each image
Usage
-----
    For the scroll interaction to work, the "scroller" object
must be returned to the calling scope
    >>> scroller = _scroll_plot2(images, names, init_z)
    >>> plt.show()
"""
fig = plt.figure(figsize=(12, 6))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122, sharex=ax1, sharey=ax1)
scroller = Scroller([ax1, ax2], images, names, init_z)
fig.canvas.mpl_connect('scroll_event', scroller.onscroll)
fig.tight_layout()
return scroller
def _scroll_plot4(images, names, init_z):
"""
    Creates a 2x2 image plot of 3d volume images
Scrolling changes the displayed slices
Parameters
----------
images: list of 4 arrays
Each array of shape (z,y,x) or (z,y,x,RGB)
names: list of 4 strings
Names for each image
Usage
-----
    For the scroll interaction to work, the "scroller" object
must be returned to the calling scope
    >>> scroller = _scroll_plot4(images, names, init_z)
    >>> plt.show()
"""
fig = plt.figure(figsize=(12, 12))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222, sharex=ax1, sharey=ax1)
ax3 = fig.add_subplot(223, sharex=ax1, sharey=ax1)
ax4 = fig.add_subplot(224, sharex=ax1, sharey=ax1)
scroller = Scroller([ax1, ax2, ax3, ax4], images, names, init_z)
fig.canvas.mpl_connect('scroll_event', scroller.onscroll)
fig.tight_layout()
return scroller
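# Illustrative usage sketch for the _scroll_plot* helpers (not part of the
# original module; the volumes below are made up and ``Scroller`` is defined
# elsewhere in this module). The returned scroller must stay referenced while
# the figure is open, otherwise scrolling stops working:
#     vols = [np.random.rand(20, 64, 64) for _ in range(4)]
#     names = ['raw', 'pred', 'target', 'error']
#     scroller = _scroll_plot4(vols, names, init_z=10)
#     plt.show()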
def _embed3d2d(a, border_width=1, normalize=False, output_ratio=1.5, ):
"""
    Embed a 3d array into a 2d matrix by tiling.
The last two dimensions of ``a`` are assumed to be spatial, the first is tiled.
"""
sh = a.shape
assert len(sh)==3
n = sh[0]
    nhor = int(np.ceil(np.sqrt(n * output_ratio)))  # aim: aspect ratio ~ output_ratio
    nvert = int(np.ceil(float(n) / nhor))  # note: may overshoot, i.e. nvert*nhor >= n
if normalize:
maxs = [np.max(a[i, :, :]) + 1e-8 for i in range(n)]
mins = [np.min(a[i, :, :]) for i in range(n)]
else:
maxs = [1] * n
mins = [0] * n
ret = np.zeros(
(nvert * (border_width + sh[1]), nhor * (border_width + sh[2])),
dtype=np.float32)
for j in range(nvert):
for i in range(nhor):
if i + j * nhor >= n:
return ret
ret[j*(border_width+sh[1]):j*(border_width+sh[1])+sh[1],
i*(border_width+sh[2]):i*(border_width+sh[2])+sh[2]] = \
(a[i+j*nhor,:,:]-mins[i+j*nhor])/(maxs[i+j*nhor]-mins[i+j*nhor])
return ret
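# Example usage of _embed3d2d (illustrative sketch with made-up data, not part
# of the original module). With the default output_ratio of 1.5, 12 maps of
# 16x16 are tiled into a 3x5 grid, i.e. an output of shape (51, 85) including
# the 1 px borders:
#     maps = np.random.rand(12, 16, 16).astype(np.float32)
#     mosaic = _embed3d2d(maps, border_width=1, normalize=True)
#     plt.imshow(mosaic, cmap='gray', interpolation='none')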
def embedfilters(filters, border_width=1, normalize=False, output_ratio=1.0,
                 rgb_axis=None):
    """
    Embed an nd array into a 2d matrix by tiling.
    The last two dimensions of ``filters`` are assumed to be spatial,
    the others are tiled recursively.
    """
    if rgb_axis is not None:
        assert filters.shape[rgb_axis]==3
        channels = []
        for i in range(3):
            sl = [slice(None), ] * filters.ndim  # avoid shadowing the builtin ``slice``
            sl[rgb_axis] = i
            f = filters[tuple(sl)]
channels.append(
embedfilters(f, border_width, normalize, output_ratio))
return np.dstack(channels)
if filters.ndim==3:
return _embed3d2d(filters, border_width, normalize, output_ratio)
elif filters.ndim > 3:
parts = []
for f in filters:
parts.append(
embedfilters(f, border_width, normalize, output_ratio))
parts = np.concatenate([x[None, ...] for x in parts])
return embedfilters(parts, border_width, normalize, output_ratio)
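# Example usage of embedfilters (illustrative sketch with made-up conv
# filters, not part of the original module): a (n_out, n_in, x, y) weight
# array is tiled recursively into a single 2d image, e.g. to visualise
# first-layer filters:
#     w = np.random.randn(16, 4, 5, 5).astype(np.float32)
#     canvas = embedfilters(w, border_width=1, normalize=True)
#     plt.imshow(canvas, cmap='gray', interpolation='none')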
def sma(c, n):
    """
    Returns the box-SMA of ``c`` with box length ``n``. The returned array
    has the same length as ``c``; the first ``n`` entries contain the
    cumulative (expanding) mean instead of the full box average.
    """
if n==0:
return c
ret = np.cumsum(c, dtype=float)
ret[n:] = (ret[n:] - ret[:-n]) / n
m = min(n, len(c))
    ret[:n] = ret[:n] / np.arange(1, m + 1)  # expanding mean for the first n entries
return ret
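# Worked example for sma (illustrative, not part of the original module): a
# box length of 2 averages consecutive pairs, while the first n entries fall
# back to the expanding mean:
#     sma(np.array([1., 2., 3., 4., 5.]), 2)
#     # -> array([1. , 1.5, 2.5, 3.5, 4.5])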
def add_timeticks(ax, times, steps, time_str='mins', num=5):
N = int(times[-1])
k = max(N / num, 1)
k = int(np.log10(k)) # 10-base of locators
m = int(np.round(float(N) / (num * 10 ** k))) # multiple of base
s = max(m * 10 ** k, 1)
    x_labs = np.arange(0, N, s, dtype=int)  # np.int is deprecated in newer numpy
x_ticks = np.interp(x_labs, times, steps)
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_labs)
ax.set_xlim(0, steps[-1])
ax.set_xlabel('Runtime [%s]' % time_str) # (%s)'%("{0:,d}".format(N)))
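# Usage sketch for add_timeticks (illustrative, synthetic data): put
# wall-clock time ticks on a secondary x-axis whose primary axis counts
# update steps, as plot_hist below does via plt.twiny():
#     steps = np.arange(300)
#     times = np.linspace(0, 45, 300)          # runtime in minutes
#     fig, ax = plt.subplots()
#     ax.plot(steps, np.random.rand(300))
#     add_timeticks(ax.twiny(), times, steps, time_str='mins')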
def plot_hist(timeline, history, save_name, loss_smoothing_length=200,
              autoscale=True):
    """Plot graphical info during training."""
plt.ioff()
try:
# Subsample points for plotting
N = len(timeline)
x_timeline = np.arange(N)
s = max((len(timeline) // 2000), 1)
x_timeline = x_timeline[::s]
timeline = timeline[::s]
s = max((len(history) // 2000), 1)
history = history[::s]
if timeline['time'][-1] < 120 * 60:
runtime = str(int(timeline['time'][-1] / 60)) + ' mins'
else:
runtime = "%.1f hours" % (timeline['time'][-1] / 3600)
# check if valid data is available
if not np.any(np.isnan(history['valid_loss'])):
l = history['valid_loss'][-10:]
else:
l = history['train_loss'][-10:]
loss_cap = l.mean() + 2 * l.std()
lt = timeline['loss'][-200:]
lt_m = lt.mean()
lt_s = lt.std()
loss_cap_t = lt_m + 2 * lt_s
loss_cap = np.maximum(loss_cap, loss_cap_t)
if np.all(timeline['loss'] > 0):
loss_floor = 0.0
else:
loss_floor = lt_m - 2 * lt_s
### Timeline, Loss ###
plt.figure(figsize=(16, 12))
plt.subplot(211)
plt.plot(x_timeline, timeline['loss'], 'b-', alpha=0.5,
label='Update Loss')
loss_smooth = sma(timeline['loss'], loss_smoothing_length)
plt.plot(x_timeline, loss_smooth, 'k-', label='Smooth update Loss',
linewidth=3)
if autoscale:
plt.ylim(loss_floor, loss_cap)
plt.xlim(0, N)
plt.legend(loc=0)
plt.hlines(lt_m, 0, N, linestyle='dashed', colors='r', linewidth=2)
plt.hlines(lt_m + lt_s, 0, N, linestyle='dotted', colors='r',
linewidth=1)
plt.hlines(lt_m - lt_s, 0, N, linestyle='dotted', colors='r',
linewidth=1)
plt.xlabel('Update steps %s, total runtime %s' % (N - 1, runtime))
ax = plt.twiny()
if timeline['time'][-1] > 120 * 60:
add_timeticks(ax, timeline['time'] / 3600, x_timeline,
time_str='hours')
else:
add_timeticks(ax, timeline['time'] / 60, x_timeline,
time_str='mins')
### Loss vs Prevalence ###
plt.subplot(212)
c = 1.0 - (timeline['time'] / timeline['time'].max())
plt.scatter(timeline['batch_char'], timeline['loss'], c=c, marker='.',
s=80, cmap='gray', edgecolors='face')
if autoscale:
bc = timeline['batch_char'][-200:]
bc_m = bc.mean()
bc_s = bc.std()
bc_cap = bc_m + 2 * bc_s
if np.all(bc > 0):
bc_floor = -0.01
else:
bc_floor = bc_m - 2 * bc_s
plt.ylim(loss_floor, loss_cap)
plt.xlim(bc_floor, bc_cap)
plt.xlabel('Mean target of batch')
plt.ylabel('Loss')
plt.tight_layout()
plt.savefig(save_name + ".timeline.png", bbox_inches='tight')
###################################################################
### History Loss ###
plt.figure(figsize=(16, 12))
plt.subplot(311)
plt.plot(history['steps'], history['train_loss'], 'g-',
label='Train Loss', linewidth=3)
plt.plot(history['steps'], history['valid_loss'], 'r-',
label='Valid Loss', linewidth=3)
if autoscale:
plt.ylim(loss_floor, loss_cap)
plt.xlim(0, history['steps'][-1])
plt.legend(loc=0)
# plt.xlabel('Update steps %s, total runtime %s'%(N-1, runtime))
ax = plt.twiny()
if timeline['time'][-1] > 120 * 60:
add_timeticks(ax, timeline['time'] / 3600, x_timeline,
time_str='hours')
else:
add_timeticks(ax, timeline['time'] / 60, x_timeline,
time_str='mins')
### History Loss gains ###
plt.subplot(312)
plt.plot(history['steps'], history['loss_gain'], 'b-',
label='Loss Gain at update', linewidth=3)
plt.hlines(0, 0, history['steps'][-1], linestyles='dotted')
plt.plot(history['steps'], history['lr'], 'r-', label='LR',
linewidth=3)
# plt.xlabel('Update steps %s, total runtime %s'%(N-1, runtime))
plt.legend(loc=3)
std = history['loss_gain'][:5].std() * 2 if len(history) > 6 else 1.0
if autoscale:
# add epsilon to suppress matplotlib warning in case of CG
plt.ylim(-std, std + 1e-10)
plt.xlim(0, history['steps'][-1])
ax2 = plt.twinx()
ax2.plot(history['steps'], history['mom'], 'r-', label='MOM')
ax2.plot(history['steps'], history['gradnetrate'], 'r-',
label='GradNetRate')
ax2.set_ylim(-1, 1)
if autoscale:
ax2.set_xlim(0, history['steps'][-1])
ax2.legend(loc=4)
### Errors ###
plt.subplot(313)
cutoff = 2
if len(history) > (cutoff + 1):
history = history[cutoff:]
# check if valid data is available
if not np.any(np.isnan(history['valid_err'])):
e = history['valid_err'][-200:]
else:
e = history['train_err'][-200:]
e_m = e.mean()
e_s = e.std()
err_cap = e_m + 2 * e_s
if np.all(e > 0):
err_floor = 0.0
else:
err_floor = e_m - 2 * e_s
plt.plot(history['steps'], history['train_err'], 'g--',
label='Train error', linewidth=1)
plt.plot(history['steps'], history['valid_err'], 'r--',
label='Valid Error', linewidth=1)
plt.plot(history['steps'], sma(history['train_err'], 8), 'g-',
label='Smooth train error', linewidth=3)
if not np.any(np.isnan(sma(history['valid_err'], 8))):
plt.plot(history['steps'], sma(history['valid_err'], 8), 'r-',
label='Smooth valid Error', linewidth=3)
if autoscale:
plt.ylim(err_floor, err_cap)
plt.xlim(0, history['steps'][-1])
plt.grid()
plt.legend(loc=0)
plt.xlabel('Update steps %s, total runtime %s' % (N - 1, runtime))
plt.tight_layout()
plt.savefig(save_name + ".history.png", bbox_inches='tight')
except ValueError:
# When arrays are empty
logger.warning("An error occurred during plotting.")
def plot_var(var, save_name):
    # columns of var: [step, nll.mean, nll.std, concentration.mean, concentration.std]
plt.figure(figsize=(16, 12))
plt.subplot(211)
plt.plot(var[:, 0], var[:, 1], 'b-', alpha=0.6)
plt.plot(var[:, 0], sma(var[:, 1], 100), 'g-', linewidth=3)
plt.plot(var[:, 0], sma(var[:, 1] + var[:, 2], 100), 'r:', linewidth=2)
plt.plot(var[:, 0], sma(var[:, 1] - var[:, 2], 100), 'r:', linewidth=2)
plt.title("NLL")
plt.subplot(212)
plt.plot(var[:, 0], var[:, 3], 'b-', alpha=0.6)
plt.plot(var[:, 0], sma(var[:, 3], 100), 'g-', linewidth=3)
plt.plot(var[:, 0], sma(var[:, 3] + var[:, 4], 100), 'r:', linewidth=2)
plt.plot(var[:, 0], sma(var[:, 3] - var[:, 4], 100), 'r:', linewidth=2)
plt.title("Concentration")
plt.savefig(save_name + ".Beta1.png", bbox_inches='tight')
plt.figure(figsize=(12, 12))
c = 1.0 - ((var[:, 0]).astype(np.float32) / var[-1, 0])
plt.subplot(221)
plt.scatter(var[:, 1], var[:, 3], c=c, marker='.', s=80, cmap='gray',
edgecolors='face')
plt.title("Concentration vs. NLL")
plt.subplot(222)
plt.scatter(var[:, 2], var[:, 3], c=c, marker='.', s=80, cmap='gray',
edgecolors='face')
plt.title("Concentration vs. NLL.std")
plt.subplot(223)
plt.scatter(var[:, 3], var[:, 4], c=c, marker='.', s=80, cmap='gray',
edgecolors='face')
plt.title("Concentration vs. Concentration.std")
plt.subplot(224)
plt.scatter(var[:, 1], var[:, 2], c=c, marker='.', s=80, cmap='gray',
edgecolors='face')
plt.title("NLL vs. NLL.std")
plt.savefig(save_name + ".Beta2.png", bbox_inches='tight')
def plot_debug(var, debug_output_names, save_name):
    # columns of var: [step, total loss, <one column per debug output>]
s = max((len(var) // 2000), 1)
var = var[::s]
plt.figure(figsize=(16, 12))
colors = ['gold', 'b', 'darkblue', 'crimson', 'navajowhite', 'deepskyblue',
'darkgray', 'maroon', 'palevioletred', 'forestgreen', ] * 2
n = len(colors) // 2
marker = ['-', ] * n + [':'] * n
lw_s = [2, ] * n + [3, ] * n
maxima = []
minima = []
total = sma(var[:, 1], 70)
maxima.append(total[-100:].max())
minima.append(total[-100:].min())
plt.plot(var[:, 0], total, 'k-', linewidth=4, label='total loss')
for i in range(len(debug_output_names)): ###TODO automatic std intervals
name = debug_output_names[i]
smooth = sma(var[:, i + 2], 70)
plt.plot(var[:, 0], smooth, color=colors[i], linestyle=marker[i],
linewidth=lw_s[i], label=name)
maxima.append(smooth[-100:].max())
minima.append(smooth[-100:].min())
plt.title("Debug Outputs")
cap_hi = np.max([x for x in maxima if np.isfinite(x)]) * 1.5
cap_lo = np.min([x for x in minima if np.isfinite(x)])
plt.ylim(cap_lo, cap_hi)
plt.legend(loc=0)
plt.hlines(0, var[0, 0], var[-1, 0], linewidth=1)
plt.grid()
plt.savefig(save_name + ".Debug.png", bbox_inches='tight')
def plot_regression(pred, target, save_name, loss_smoothing_length=200,
                    autoscale=True):
    """Scatter plot of predictions vs. targets during training."""
try:
# Subsample points for plotting
N = len(pred)
s = max((len(pred) // 2000), 1)
pred = pred[::s].ravel()
target = target[::s].ravel()
N = len(pred)
x_timeline = np.arange(N)
c = N - x_timeline
plt.figure(figsize=(8, 8))
        ### Prediction vs. target scatter ###
plt.scatter(pred, target, c=c, marker='.', s=80, cmap='gray',
edgecolors='face')
m = np.minimum(pred.min(), target.min())
M = np.maximum(pred.max(), target.max())
plt.plot([m, M], [m, M], 'r:')
plt.ylim(m, M)
plt.xlim(m, M)
plt.xlabel('Prediction')
plt.ylabel('Target')
plt.tight_layout()
plt.savefig(save_name + ".regression.png", bbox_inches='tight')
except ValueError:
# When arrays are empty
logger.warning("An error occurred during regression plotting.")
def plot_kde(pred, target, save_name, limit=90, scale='same', grid=50,
             take_last=4000):
try:
if take_last:
pred = pred[-take_last:].ravel()
target = target[-take_last:].ravel()
if limit=='max':
mp, mt = pred.min(), target.min()
Mp, Mt = pred.max(), target.max()
else:
lo = 100 - limit
mp, mt = np.percentile(pred, lo), np.percentile(target, lo)
Mp, Mt = np.percentile(pred, limit), np.percentile(target, limit)
if scale=='same':
mp = np.minimum(mp, mt)
Mp = np.maximum(Mp, Mt)
mt = mp
Mt = Mp
if isinstance(grid, int):
grid = [grid, grid]
        pg, tg = np.mgrid[mp:Mp:grid[0] * 1j, mt:Mt:grid[1] * 1j]
positions = np.vstack([pg.ravel(), tg.ravel()])
values = np.vstack([pred, target])
kernel = stats.gaussian_kde(values)
f = np.reshape(kernel(positions).T, pg.shape)
plt.figure()
plt.xlim(mp, Mp)
plt.ylim(mt, Mt)
plt.xlabel("Prediction")
plt.ylabel("Target")
plt.imshow(np.rot90(f), cmap=plt.cm.gist_earth_r,
extent=[mp, Mp, mt, Mt])
plt.contour(pg, tg, f)
plt.plot([mt, Mt], [mt, Mt], 'r:')
plt.tight_layout()
plt.savefig(save_name + ".regression_kde.png", bbox_inches='tight')
except ValueError:
# When arrays are empty
logger.warning("An error occurred during regression kde plotting.")
def my_quiver(x, y, img=None, c=None):
    """
    Quiver plot of a 2d vector field, optionally overlaid on an image.
    The first dim of ``x``, ``y`` changes along the vertical axis,
    the second dim changes along the horizontal axis.
    x: vertical vector component
    y: horizontal vector component
    """
figure = plt.figure(figsize=(7, 7))
if img is not None:
plt.imshow(img, interpolation='none', alpha=0.22, cmap='gray')
plt.quiver(x, y, c, angles='xy', units='xy', cmap='spring', pivot='middle',
scale=0.5)
return figure
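# Usage sketch for my_quiver (made-up vector field): x holds the vertical and
# y the horizontal components on a 2d grid; c colours the arrows:
#     yy, xx = np.mgrid[0:20, 0:20]
#     fig = my_quiver(np.sin(xx / 3.0), np.cos(yy / 3.0),
#                     img=np.random.rand(20, 20), c=xx + yy)
#     plt.show()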
def plot_trainingtarget(img, lab, stride=1):
"""
Plots raw image vs target to check if valid batches are produced.
Raw data is also shown overlaid with targets
Parameters
----------
img: 2d array
raw image from batch
lab: 2d array
targets
stride: int
strides of targets
"""
if len(lab) * stride!=len(img):
off = (len(img) - stride * len(lab)) // 2 // stride
if lab.ndim==3:
assert lab.shape[2]==3
new_t = np.zeros(
(lab.shape[0] + 2 * off, lab.shape[1] + 2 * off, 3))
new_t[off:-off, off:-off, :] = lab
else:
new_t = np.zeros((lab.shape[0] + 2 * off, lab.shape[1] + 2 * off))
new_t[off:-off, off:-off] = lab
lab = new_t
if lab.ndim==3:
assert lab.shape[2]==3
img = img[:, :, None]
img = np.repeat(img, 3, axis=2)
plt.figure(figsize=(18, 6))
plt.subplot(131)
plt.imshow(img, interpolation='none', cmap=plt.get_cmap('gray'))
plt.title('data')
plt.subplot(132)
plt.imshow(lab, interpolation='none', cmap=plt.get_cmap('gray'))
plt.title('target')
if img.shape==lab.shape:
overlay = 0.75 * img + 0.25 * (1 - lab)
plt.subplot(133)
plt.imshow(overlay, interpolation='none', cmap=plt.get_cmap('gray'))
plt.title('overlay')
plt.show()
return img - lab
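# Usage sketch for plot_trainingtarget (made-up batch): a raw patch and a
# smaller target patch, as a valid-convolution training pipeline would
# produce; the target is zero-padded to the raw size for the overlay:
#     img = np.random.rand(64, 64)
#     lab = np.random.randint(0, 2, (48, 48)).astype(np.float32)
#     plot_trainingtarget(img, lab, stride=1)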
def plot_exectimes(exectimes, save_path='~/exectimes.png', max_items=32):
    """
    Plot model execution time dict obtained from
    elektronn2.neuromancer.model.Model.measure_exectimes()

    Parameters
    ----------
    exectimes: OrderedDict
        Execution times (output of Model.measure_exectimes())
    save_path: str
        Where to save the plot
    max_items: int
        Only the ``max_items`` largest execution times are given names and
        plotted individually. Everything else is grouped under
        '(other nodes)'.
    """
thresh_val = 0
if len(exectimes) > max_items:
thresh_val = sorted(list(exectimes.values()))[-max_items]
filt_rtimes = OrderedDict()
for key, val in exectimes.items():
if val >= thresh_val:
filt_rtimes[key] = val
other = sum(exectimes.values()) - sum(filt_rtimes.values())
node_names = list(filt_rtimes.keys())
node_exectimes = list(filt_rtimes.values())
if len(exectimes) > max_items:
node_names += ['(other nodes)']
node_exectimes += [other]
cs = plt.cm.Set1(np.arange(len(node_exectimes)) / (len(node_exectimes)))
sns.set_style("whitegrid")
plt.figure(figsize=(13, 12))
plt.title('Node execution times')
plt.ylabel('Node')
plt.xlabel('Time (in ms)')
ax = sns.barplot(y=node_names, x=node_exectimes)
ax.get_figure().savefig(os.path.expanduser(save_path), bbox_inches='tight')
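# Usage sketch for plot_exectimes (hypothetical timings in ms; normally this
# dict comes from Model.measure_exectimes()):
#     exectimes = OrderedDict([('conv1', 4.2), ('conv2', 3.1), ('out', 0.7)])
#     plot_exectimes(exectimes, save_path='~/exectimes.png')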