# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Marius Killinger
# All rights reserved
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip
# TODO: Python 3 compatibility
__all__ = ['trace_zyx2xyz', 'trace_to_kzip', 'SkeletonMFK',
'Trace']
import os
import sys
from subprocess import check_call
import logging
from collections import OrderedDict
from scipy import interpolate
from scipy import sparse
from scipy.sparse import csgraph
import numpy as np
from knossos_utils import skeleton as knossos_skeleton # TODO: Mark as dependency when knossos_utils is published.
from .. import utils
from ..config import config
from . import transformations
logger = logging.getLogger('elektronn2log')
inspection_logger = logging.getLogger('elektronn2log-inspection')
with open(os.devnull, 'w') as devnull:
    # mayavi does not raise a catchable exception when no display is
    # available but takes the whole script down, so probe for X first.
    try:
        # "xset q" will always succeed to run if an X server is currently running
        check_call(['xset', 'q'], stdout=devnull, stderr=devnull)
        import mayavi.mlab as mlab
        # Don't set backend explicitly, use system default...
    # If "xset q" fails (or mayavi is not installed) conclude that plotting
    # is unavailable. Expected: OSError, ImportError, CalledProcessError,
    # ValueError. ``Exception`` (instead of a bare ``except``) keeps
    # KeyboardInterrupt / SystemExit propagating.
    except Exception:
        logger.warning("No mayavi imported, cannot plot skeletons")
        mlab = None
###############################################################################
# Constants for scaling of radius (used by SkeletonMFK.get_scale_factor)
REF_RADIUS = 20.0  # target radius: the proposed zoom factor is REF_RADIUS / true radius
BASE = 1.3  # multiplicative step applied to the scale factor when zooming in
BASE_I = BASE ** -1  # multiplicative step applied when zooming out
HYST = 0.75 # 0.5: no memory 1.0 complete non-overlap
assert 0.5 <= HYST <= 1.0
BASE_H = BASE ** HYST  # threshold on the relative change above which a zoom-in step is taken
BASE_IH = BASE ** -HYST  # threshold on the relative change below which a zoom-out step is taken
@utils.timeit
@utils.my_jit(nopython=True)
def insert(cube, coords, i, off):
    """
    Write the value ``i`` into ``cube`` at every coordinate row of ``coords``.

    Each row of ``coords`` is shifted by ``-off`` before it is used as a 3d
    index. ``cube`` is modified in place; nothing is returned.
    """
    off0 = off[0]
    off1 = off[1]
    off2 = off[2]
    n_points = coords.shape[0]
    for row in np.arange(n_points):
        cube[coords[row, 0] - off0,
             coords[row, 1] - off1,
             coords[row, 2] - off2] = i
@utils.timeit
@utils.my_jit(nopython=True)
def insert_vec(cube, coords, vec, off):
    """
    Write the vectors ``vec`` into the 4d array ``cube`` (in place).

    Row ``i`` of ``vec`` is written, component by component, to
    ``cube[coords[i] - off]``. A component that is already occupied
    (absolute value > 1e-5, or already marked NaN by an earlier collision)
    is set to NaN so that ambiguous voxels are excluded from training.

    Returns
    -------
    double_inserts: number of component writes that hit an occupied entry
    """
    n = len(coords)
    m = vec.shape[1]  # also valid for an empty ``vec``
    double_inserts = 0
    for i in np.arange(n):
        x = coords[i, 0] - off[0]
        y = coords[i, 1] - off[1]
        z = coords[i, 2] - off[2]
        for j in np.arange(m):
            old = cube[x, y, z, j]
            # NaN must count as occupied: abs(nan) > 1e-5 is False, so a
            # third insert would otherwise overwrite the "ambiguous" marker.
            if abs(old) > 1e-5 or np.isnan(old):
                double_inserts += 1
                cube[x, y, z, j] = np.nan  # in case of doubt, dont train here...
            else:
                cube[x, y, z, j] = vec[i, j]
    return double_inserts
@utils.timeit
@utils.my_jit(nopython=True)
def ray_cast(max_dists, hull_points, hull_dist, ray_steps, hull_cube, off):
    """
    Cast a ray from every hull point through the dense ``hull_cube``.

    Starting at ``hull_points[i] - off`` the position is advanced in steps of
    0.9 * ``ray_steps[i]`` for as long as it stays inside the cube and a
    ``True`` voxel is found at the position or one of its half-voxel-rounded
    neighbours (at most 201 steps). ``max_dists[i]`` receives
    ``hull_dist[i] + 1e-5`` plus 0.9 for every accepted step.
    ``max_dists`` is modified in place; nothing is returned.
    """
    s = np.float32(0.9)  # step length
    sh = hull_cube.shape
    for i in np.arange(len(hull_points)):  # take hull point
        # initialise dist and position
        dist = hull_dist[i] + 1e-5
        x = hull_points[i, 0] - off[0]
        y = hull_points[i, 1] - off[1]
        z = hull_points[i, 2] - off[2]
        count = 0
        while True:
            count += 1
            x = x + s * ray_steps[i, 0]
            y = y + s * ray_steps[i, 1]
            z = z + s * ray_steps[i, 2]
            # np.int64 truncates like the removed ``np.int`` alias did
            x0 = np.int64(x)
            y0 = np.int64(y)
            z0 = np.int64(z)
            if x0 < 0 or y0 < 0 or z0 < 0:
                break
            x1 = np.int64(x + 0.5)
            y1 = np.int64(y + 0.5)
            z1 = np.int64(z + 0.5)
            if x1 >= sh[0] or y1 >= sh[1] or z1 >= sh[2]:
                break
            # search if hull is True in neighbourhood of x,y,z
            found = hull_cube[x0, y0, z0] or \
                    hull_cube[x1, y0, z0] or \
                    hull_cube[x0, y1, z0] or \
                    hull_cube[x0, y0, z1] or \
                    hull_cube[x1, y1, z0] or \
                    hull_cube[x1, y0, z1] or \
                    hull_cube[x0, y1, z1] or \
                    hull_cube[x1, y1, z1]
            if not found:
                break
            if count > 200:
                break
            dist = dist + s
        max_dists[i] = dist
@utils.my_jit(nopython=True, cache=True)
def find_peaks_helper(padded_cube, peak_cube):
    """
    Mark local maxima of a 3d array.

    ``padded_cube`` is the input padded by one voxel on every side. For each
    interior voxel that is >= all of its 26 neighbours, its value is written
    to the corresponding (unpadded) position of ``peak_cube`` (in place).
    """
    sh = padded_cube.shape
    offsets = np.arange(-1, 2)  # hoisted: -1, 0, +1 along every axis
    for z in np.arange(1, sh[0]-1):
        for x in np.arange(1, sh[1]-1):
            for y in np.arange(1, sh[2]-1):
                center = padded_cube[z, x, y]
                # Looping over the full 3x3x3 neighbourhood guarantees that
                # no neighbour is forgotten or tested twice. The centre
                # compares equal to itself and needs no special case.
                is_peak = True
                for dz in offsets:
                    for dx in offsets:
                        for dy in offsets:
                            # "not >=" (rather than "<") keeps NaN
                            # neighbours disqualifying a peak
                            if not (center >= padded_cube[z+dz, x+dx, y+dy]):
                                is_peak = False
                                break
                        if not is_peak:
                            break
                    if not is_peak:
                        break
                if is_peak:
                    peak_cube[z-1, x-1, y-1] = center
def find_peaks(cube):
    """
    Find the local maxima of the 3d array ``cube``.

    Returns
    -------
    indices: flat indices (into ``cube.ravel()``) of all non-zero peaks
    maxima: the peak values
    Both are sorted by ascending peak value.
    """
    peak_values = np.zeros_like(cube)
    # zero-padding lets the helper look at all neighbours of border voxels
    find_peaks_helper(np.pad(cube, 1, mode='constant'), peak_values)
    flat_idx = np.flatnonzero(peak_values)
    heights = cube.ravel()[flat_idx]
    order = np.argsort(heights)
    return flat_idx[order], heights[order]
# WARNING / NOTE: skeleton objects are in xyz-order
[docs]class SkeletonMFK(object):
"""
Joints: all branches and end points / node terminations (nodes not of deg 2)
Branches: Joints of degree >= 3
"""
@staticmethod
[docs] def find_joints(node_list):
joints = {}
branches = {}
for node in node_list:
if node.degree() > 2: # branching point
joints[node.ID] = node
branches[node.ID] = node
if node.degree()==1:
joints[node.ID] = node # end point
return joints, branches
def __init__(self, aniso_scale=2, name=None, skel_num=None):
self.aniso_scale = np.array([[1,1,aniso_scale]], dtype=np.float32)
self.bones = dict()
self.edges = list()
self.branches = dict()
self.joints = dict()
self.all_nodes = None
self.hull_points = None
self.hull_skel = dict()
self.hull_branch = dict()
self.name = name
self.skel_num = skel_num
self.radii = dict()
self.all_radii = None
self.joint_radii = None
self.props = dict()
self.all_props = None
self.joint_props = None
self.joint_id2joint_index = dict()
# For training
self.kdt_hull = None
self.linked_data = None
self.lost_track = False
self.position_s = None
self.position_l = None
self.direction_il = None
self.start_new_training = True
self.prev_batch = None
self.trafo = None
self.prev_scale = 1.0
self.prev_gamma = 0.0
self.training_traces = []
self.background_processes = False
self._hull_point_bg = dict()
self.cnn_grid = None
# Old for training
self.debug_traces = []
self.debug_traces_current = []
self.debug_grads = []
self.debug_grads_current = []
def init_from_annotation(self, skeleton_annotatation, min_radius=None,
                         interpolation_resolution=0.5, interpolation_order=1):
    """
    Fill this skeleton from an annotation object.

    The nodes are split into joints (degree 1 or > 2) and bones (node chains
    between two joints), converted to arrays and optionally resampled.

    Parameters
    ----------
    skeleton_annotatation: annotation providing ``getNodes()``; the nodes
        must provide ``degree()``, ``ID``, ``getNeighbors()``,
        ``getCoordinate()`` and ``getDataElem()``
    min_radius: if truthy, lower bound applied to ``all_radii``
    interpolation_resolution: spacing of the resampled bone nodes
        (None: no resampling)
    interpolation_order: maximal spline order used for resampling
    """
    # Read annotation data structures and convert to dicts and np.ndarrays
    self.joints, self.branches = self.find_joints(skeleton_annotatation.getNodes())
    visited = {n: False for n in skeleton_annotatation.getNodes()}
    for joint_id, joint in self.joints.items():
        directions = joint.getNeighbors()
        for d in directions:
            if visited[d]:  # we have visited this bone already
                continue
            visited[d] = True  # mark as visited
            bone = OrderedDict()  # create new bone
            bone[joint] = True  # start the bone at the joint
            current_node = d  # next go to the node in the selected direction
            while True:
                bone[current_node] = True
                if current_node.degree() > 2 or current_node.degree()==1:
                    # At new branch or end point, the bone ends here
                    # add edge between starting joint and this branch
                    if joint_id < current_node.ID:
                        edge = (joint_id, current_node.ID)
                    else:
                        edge = (current_node.ID, joint_id)
                    self.edges.append(edge)
                    break
                else:  # The node has 2 neighbours, one from which we come
                    # and another one to which we go
                    nb = list(current_node.getNeighbors())
                    assert len(nb) == 2
                    # Test which node we visit next
                    if nb[0] in bone:
                        assert nb[1] not in bone
                        current_node = nb[1]
                    if nb[1] in bone:
                        assert nb[0] not in bone
                        current_node = nb[0]
            self.bones[edge] = list(bone.keys())

    # Convert bones to arrays
    for edge, bone in self.bones.items():
        self.bones[edge] = np.array([x.getCoordinate() for x in bone], dtype=np.float32)
        self.radii[edge] = np.array([x.getDataElem('radius') for x in bone], dtype=np.float32)
        try:
            axoness_pred = np.array([x.getDataElem('axoness_pred') for x in bone], dtype=np.int16)
            spiness_pred = np.array([x.getDataElem('spiness_pred') for x in bone], dtype=np.int16)
            props = np.concatenate([axoness_pred[:,None], spiness_pred[:,None]], axis=1)
            self.props[edge] = props
        except KeyError:  # annotation carries no property predictions
            pass

    # convert joints to arrays
    self.joint_radii = np.array([x.getDataElem('radius') for x in self.joints.values()], dtype=np.float32)
    try:
        axoness_pred = np.array([x.getDataElem('axoness_pred') for x in self.joints.values()], dtype=np.int16)
        spiness_pred = np.array([x.getDataElem('spiness_pred') for x in self.joints.values()], dtype=np.int16)
        self.joint_props = np.concatenate([axoness_pred[:,None], spiness_pred[:,None]], axis=1)
    except KeyError:
        pass
    self.joint_id2joint_index = dict(zip(self.joints.keys(), range(len(self.joints))))
    self.joints = np.array([x.getCoordinate() for x in self.joints.values()], dtype=np.float32)
    # convert branches to arrays
    self.branches = np.array([x.getCoordinate() for x in self.branches.values()], dtype=np.float32)

    if interpolation_resolution is not None:
        for edge, bone in self.bones.items():
            if len(bone)<=1:
                continue
            keep_index = None
            try:
                new_bone = self.interpolate_bone(bone, max_k=interpolation_order,
                                                 resolution=interpolation_resolution)
                self.radii[edge] = self.interpolate_prop(bone, self.radii[edge],
                                                         new_bone)
            except Exception:
                # Retry with duplicate rows removed.
                # NOTE(review): presumably the spline fit fails on repeated
                # coordinates; confirm.
                bone, keep_index = utils.unique_rows(bone)
                new_bone = self.interpolate_bone(bone, max_k=interpolation_order,
                                                 resolution=interpolation_resolution)
                self.radii[edge] = self.interpolate_prop(bone, self.radii[edge][keep_index],
                                                         new_bone)
            props = self.props.get(edge)  # None if the annotation has no props
            if props is not None:
                if keep_index is not None:
                    # keep props aligned with the de-duplicated bone
                    props = props[keep_index]
                self.props[edge] = self.interpolate_prop(bone, props, new_bone,
                                                         discrete=True)
            self.bones[edge] = new_bone

    self.all_nodes = np.vstack([self.joints,] + list(self.bones.values()))
    self.all_radii = np.hstack([self.joint_radii,] + list(self.radii.values()))
    try:
        self.all_props = np.vstack([self.joint_props, ] + list(self.props.values()))
    except Exception:  # best effort: leave all_props unset when stacking fails
        pass
    if min_radius:
        self.all_radii = np.maximum(self.all_radii, min_radius)
def save(self, fname):
    """Store this skeleton object under ``fname`` using ``utils.picklesave``."""
    utils.picklesave(self, fname)
[docs] def interpolate_bone(self, bone, max_k=1, resolution=0.5):
bone_iso = bone * self.aniso_scale
linear_distances = np.linalg.norm(np.diff(bone_iso, axis=0), axis=1)
total_dist = linear_distances.sum()
k = min(max_k, bone_iso.shape[0]-1)
tck, u = interpolate.splprep(bone_iso.T, k=k)
n = max(2, int(float(total_dist) / resolution))
new = interpolate.splev(np.linspace(0,1,n), tck)
new = np.array(new).T / self.aniso_scale
return new
[docs] def interpolate_prop(self, old_bone, old_prop, new_bone, discrete=False):
dtype = np.int16 if discrete else np.float32
new_prop = np.zeros((len(new_bone),)+old_prop.shape[1:], dtype=dtype)
old_bone_iso = old_bone * self.aniso_scale
new_bone_iso = new_bone * self.aniso_scale
start_i = 0
stop_i = 1
min_dist = np.linalg.norm(new_bone_iso[0] - old_bone_iso[stop_i])
for i in range(len(new_bone)):
dist_start = np.linalg.norm(new_bone_iso[i] - old_bone_iso[start_i])
dist_stop = np.linalg.norm(new_bone_iso[i] - old_bone_iso[stop_i])
min_dist = min(min_dist, dist_stop)
if (min_dist < dist_stop) and stop_i+1<len(old_bone):
stop_i += 1
start_i += 1
dist_start = dist_stop
dist_stop = np.linalg.norm(new_bone_iso[i] - old_bone_iso[stop_i])
min_dist = dist_stop
if discrete:
if dist_stop > dist_start:
new_prop[i] = old_prop[start_i]
else:
new_prop[i] = old_prop[stop_i]
else:
d = dist_start + dist_stop
new_prop[i] = dist_stop/d * old_prop[start_i] + dist_start/d * old_prop[stop_i]
return new_prop
@utils.cache()
def get_kdtree(self, static_points, k=1, jobs=-1):
    """
    Build a ``utils.KDT`` nearest-neighbour index over ``static_points``.

    The points are multiplied by ``aniso_scale`` before fitting, so the
    fitted data is in scaled coordinates. ``k`` is passed as ``n_neighbors``
    and ``jobs`` as ``n_jobs``.
    """
    scaled_points = static_points * self.aniso_scale  # change metric
    tree = utils.KDT(n_neighbors=k, n_jobs=jobs,
                     algorithm='kd_tree', leaf_size=20)
    tree.fit(scaled_points)
    return tree
@utils.cache()
def get_knn(self, kdt, query_points, k=None):
    """
    Query the ``k`` nearest fitted points for every query point.

    Parameters
    ----------
    kdt: index created by ``get_kdtree``
    query_points: query coordinates (unscaled)
    k: number of neighbours (default: ``kdt.n_neighbors``)

    Returns
    -------
    distances: float32 distances as returned by ``kdt.kneighbors`` for
        the scaled query
    indices: indices into the fitted points
    coordinates: coordinates of the neighbours, divided by ``aniso_scale``
    For ``k == 1`` the neighbour axis is dropped from distances and indices.
    """
    if k is None:
        k = kdt.n_neighbors
    scaled_query = query_points * self.aniso_scale  # change metric
    distances, indices = kdt.kneighbors(scaled_query, n_neighbors=k)
    # Attention those still have the aniso scale in [:,2]
    fitted_points = kdt._fit_X.astype(np.float32)
    if k == 1:
        indices = indices[:, 0]
        distances = distances[:, 0]
    distances = distances.astype(np.float32)
    coordinates = fitted_points[indices] / self.aniso_scale  # change to pixel coordinates
    if k != 1:
        assert coordinates.shape[1] == k
    return distances, indices, coordinates
def get_closest_node(self, position_s):
    """
    Look up the skeleton node (``all_nodes``) closest to ``position_s``.

    Returns
    -------
    dist: float32 distance(s) to the closest node
    ind: index/indices into ``all_nodes``
    nearest_s: coordinates of the closest node(s)
    For a single 1d position the leading axis of all results is dropped.
    """
    tree = self.get_kdtree(self.all_nodes, k=1, jobs=1)
    dist, ind, nearest_s = self.get_knn(tree, position_s)
    single_query = position_s.ndim == 1
    if single_query:
        dist, ind, nearest_s = dist[0], ind[0], nearest_s[0]
    return dist.astype(np.float32), ind, nearest_s
### Sampling routines for getting training data ###
[docs] def sample_skel_point(self, rng, joint_ratio=None):
n = len(self.all_nodes)
if joint_ratio:
if rng.rand() < joint_ratio:
n = len(self.joints)
i = rng.randint(n)
node = self.all_nodes[i]
return node, i
def sample_tube_point(self, rng, r_max_scale=0.9, joint_ratio=None):
    """
    Skeleton node based sampling of a point inside the neurite tube.

    A random node is picked, then a random direction orthogonal to the local
    skeleton direction, and the point is moved along it by a distance drawn
    uniformly from [0, r_max_scale * local radius].

    If hull points are present, a proposal is only accepted if it is closer
    than 1.5 to a hull point. After half of the 30 retries the radius is
    halved once; when the retries are used up the node itself is returned.

    Raises
    ------
    RuntimeError: if hull points exist but ``kdt_hull`` is not initialised
    """
    hull_tree = None
    if self.hull_points is not None:
        if self.kdt_hull is None:
            raise RuntimeError("Hull kdts must be pre initialised")
        hull_tree = self.kdt_hull

    node, node_i = self.sample_skel_point(rng, joint_ratio)
    direc_iso = self.sample_local_direction_iso(node)
    local_r = self.all_radii[node_i] * r_max_scale
    max_count = 30
    half_count = max_count / 2
    clipped = False
    attempts = 0
    while True:
        r = rng.rand() * local_r
        # random unit vector, uniform on the sphere
        phi = rng.rand() * 2 * np.pi
        cos_theta = rng.rand() * 2 - 1
        sin_theta = np.sqrt(1 - cos_theta ** 2)
        rand_vec = np.array([np.cos(phi) * sin_theta,
                             np.sin(phi) * sin_theta,
                             cos_theta])
        # the cross product is orthogonal to the local direction
        ortho_iso = np.cross(direc_iso, rand_vec)
        ortho_iso /= np.linalg.norm(ortho_iso)
        ortho = ortho_iso / self.aniso_scale[0]
        proposal = node + ortho * r
        if hull_tree is None:
            return proposal

        dist = self.get_knn(hull_tree, proposal)[0][0]
        if dist < 1.5:  # we are within hull
            return proposal
        if attempts >= half_count and not clipped:
            local_r *= 0.5
            clipped = True
            logger.debug("Sample hull point: clipped r")
        if attempts >= max_count:
            logger.debug(
                "Sample hull point: max count %i reached" % max_count)
            return node
        attempts += 1
def sample_local_direction_iso(self, point, n_neighbors=6):
    """
    Estimate the local skeleton direction/orientation at ``point``.

    A line is fitted through the ``n_neighbors`` nearest skeleton nodes (in
    anisotropy-scaled coordinates) by taking the first right-singular vector
    of the centred neighbour coordinates. The returned unit vector carries
    whatever sign the SVD yields.
    """
    tree = self.get_kdtree(self.all_nodes, k=n_neighbors, jobs=1)
    _, _, coord = self.get_knn(tree, point)
    # maybe use dist as weights for svd?
    neibs_iso = coord[0] * self.aniso_scale  # transform to iso space
    centred = neibs_iso - neibs_iso.mean(axis=0)
    principal = np.linalg.svd(centred)[2][0]  # direction of largest variance
    principal /= np.linalg.norm(principal, axis=0)  # normalise
    return principal
def sample_tracing_direction_iso(self, rng, local_direction_iso, c=0.5):
    """
    Sample a direction close to the local direction.

    The normalised (0, 1) angle of deviation ``a`` has the prior
    p(a) = 1/N * (1 - c*a), with N = 1 - c/2. A threshold angle theta is
    drawn from it via the inverse cdf, then random unit vectors are drawn
    until one deviates by less than theta (+0.01 rad) from the (randomly
    signed) local direction.

    Parameters
    ----------
    rng: random state providing ``rand``
    local_direction_iso: unit vector of the local direction (not modified)
    c: slope of the prior; 0 gives a uniform prior

    Returns
    -------
    proposal: float64 unit vector. If no proposal is accepted after 1001
        tries, the signed local direction itself is returned.
    """
    # The sign is undefined, choose randomly. Negate into a new array so the
    # caller's vector is left untouched.
    if rng.rand() > 0.5:
        direction = -local_direction_iso
    else:
        direction = local_direction_iso
    u = rng.rand()
    if c == 0:
        tmp = u  # limit of the expression below for c -> 0
    else:
        # inverse cdf of the prior: theta scaled between 0 and 1
        tmp = (1 - np.sqrt(1-(2*c - c**2)*u)) / c
    # theta scaled between 0 and 90 deg in rad i.e. 0 and pi/2
    theta = tmp * 0.5 * np.pi
    max_count = 1000
    count = 0
    while True:
        proposal = rng.rand(3) * 2 - 1
        proposal /= np.linalg.norm(proposal, axis=0)  # normalise
        cos_alpha = np.dot(proposal, direction)
        if cos_alpha < 0:  # flip to next best within +/- 90 deg
            cos_alpha *= -1
            proposal *= -1
        alpha = np.arccos(cos_alpha)
        if alpha < theta + 0.01:
            break
        count += 1
        if count > max_count:
            logger.debug("Sample tracing directions: max count reached")
            # Fall back to the local direction (deviation 0) instead of
            # returning the last, rejected proposal.
            proposal = np.array(direction, dtype=np.float64)
            break
    return proposal
### Loss and loss gradient for Theano Graph ###
def get_loss_and_gradient(self, new_position_s, cutoff_inner=1.0/3,
                          rise_factor=0.1):
    """
    Loss and loss gradient of a predicted position (zoned error surface).

    * flat (zero loss/gradient) in the inner hull selected by ``cutoff_inner``
    * constant gradient magnitude in the "outer" hull, pointing from the
      nearest inner hull voxel to the position
    * gradient growing with distance (scaled by ``rise_factor``) for
      positions outside the hull; ``lost_track`` is set in this case

    The position and its gradient are appended to the debug lists.

    Returns
    -------
    loss: float32 array with one element (distance to the inner hull or 0)
    grad_s: gradient vector
    """
    inner_hull, indices = self.get_hull_points_inner(cutoff_inner,
                                                     return_indices=True)
    tree = self.get_kdtree(inner_hull, k=1, jobs=1)
    dists, inds, nearest = self.get_knn(tree, new_position_s)
    dist, ind, nearest_s = dists[0], inds[0], nearest[0]
    if config.inspection:
        inspection_logger.info("nearest_s: %s"% (nearest_s.tolist()))

    # The maximal distance if within hull is exactly
    # np.linalg.norm(np.multiply([0.5, 0.5, 0.6], [1, 1, 2])) = 1.22...
    # --> add some margin
    if dist < 1.5:  # we are within inner hull
        loss = 0.0
        grad_s = np.zeros((3,), dtype=np.float32)
        self.lost_track = False
    else:
        loss = dist
        # max dist of closest node
        max_dist = self.hull_skel['max_dist'][indices[ind]]
        # pointing from nearest to new position
        unit_grad = (new_position_s - nearest_s)
        unit_grad /= np.linalg.norm(unit_grad * self.aniso_scale[0], axis=0)
        outside_hull = not (max_dist > dist)
        self.lost_track = outside_hull
        if outside_hull:
            factor = rise_factor * (dist - max_dist)
            grad_s = unit_grad * (1 + factor)
        else:  # we are in hull but not in inner tube
            grad_s = unit_grad * 1.0

    self.debug_traces_current.append(new_position_s)
    self.debug_grads_current.append(grad_s)
    return np.array([loss,], dtype=np.float32), grad_s
def _new_training_trace(self, **get_batch_kwargs):
    """
    Prepare skeleton for a new training (sample location/direction,
    reset stuff).

    The previous trace (if any) is archived in ``training_traces``, a new
    ``Trace`` is started at a randomly sampled tube point and the position /
    direction attributes are reset.

    Parameters
    ----------
    get_batch_kwargs: must contain ``r_max_scale`` and
        ``tracing_dir_prior_c``, may contain ``joint_ratio``
    """
    # getattr: the attribute does not exist before the first trace was
    # created (e.g. on instances whose __init__ did not set it).
    previous_trace = getattr(self, 'current_trace', None)
    if previous_trace:
        if len(self.training_traces) > 20:
            self.training_traces = self.training_traces[-2:]
        self.training_traces.append(previous_trace)

    self.current_trace = Trace(linked_skel=self)
    r_max_scale = get_batch_kwargs['r_max_scale']
    tracing_dir_prior_c = get_batch_kwargs['tracing_dir_prior_c']
    joint_ratio = get_batch_kwargs.get('joint_ratio', None)
    position_s = self.sample_tube_point(self.linked_data.rng,
                                        r_max_scale=r_max_scale,
                                        joint_ratio=joint_ratio)
    if config.inspection:
        inspection_logger.info("Start new training")

    local_direc_is = self.sample_local_direction_iso(position_s,
                                                     n_neighbors=6)
    tracing_direc_is = self.sample_tracing_direction_iso(self.linked_data.rng,
                                                         local_direc_is,
                                                         c=tracing_dir_prior_c)
    self.position_s = position_s
    self.position_l = position_s[::-1]  # from lab2data (xyz)->(zxy)
    self.direction_il = tracing_direc_is[::-1]  # from lab2data (xyz)->(zxy)
    self.current_trace.append(position_s, coord_cnn=[0,]*3,
                              grad=[0,]*3, features=[0,]*7)
    self.lost_track = False
    self.trafo = None
@staticmethod
def get_scale_factor(radius, old_factor, scale_strenght):
    """
    Compute the zoom factor for the next step (with hysteresis).

    Parameters
    ----------
    radius: predicted radius (not the true radius)
    old_factor: factor by which the radius prediction and the image was scaled
    scale_strenght: limits the maximal scale factor

    Returns
    -------
    new_factor: ``old_factor`` either unchanged or multiplied by
        ``BASE`` / ``BASE_I``
    """
    # if old was large (zoom in), radius is smaller
    hi = BASE ** (scale_strenght * 2) + 1e-3  # e.g 1.69 for 1.3**2
    lo = BASE_I ** (scale_strenght * 4) - 1e-3  # e.g. 0.35 for 1/1.3 ** 4
    radius_true = radius / old_factor
    new_factor = REF_RADIUS / radius_true
    new_factor = np.clip(new_factor, lo, hi)
    change = new_factor / old_factor
    # Only move by one BASE step, and only if the proposed change exceeds
    # the hysteresis thresholds.
    if new_factor > 1.0:  # left side
        if change >= BASE_H:  # growing
            new_factor = old_factor * BASE
        elif change < BASE_IH:
            new_factor = old_factor * BASE_I
        else:
            new_factor = old_factor
    elif new_factor < 1.0:  # right side
        if change <= BASE_IH:  # zoom out
            new_factor = old_factor * BASE_I
        elif change > BASE_H:  # zoom in
            new_factor = old_factor * BASE
        else:
            new_factor = old_factor
    else:
        new_factor = old_factor

    if config.inspection:
        # factor0 is the unclipped proposal; use the constant rather than
        # a duplicated literal so the log cannot drift from the computation
        inspection_logger.info("SCALE: %.2f -> %.2f, factor0: %.2f, factor: %.2f"
                               % (radius, radius_true, REF_RADIUS / radius_true, new_factor))
    return new_factor
@staticmethod
@utils.cache
def make_grid(t_grid_sh, z_shift):
    """
    Create a regular, centred coordinate grid.

    Parameters
    ----------
    t_grid_sh: tagged shape (pixel shape + strides)
    z_shift: shift of center (positive means more look ahead)

    Returns
    -------
    points: coordinate list zyx order (float32, one row per grid point)
    zz,yy,xx: coordinate meshgrid
    """
    shape = np.array(t_grid_sh.spatial_shape)
    strides = np.array(t_grid_sh.strides)
    half_extent = ((shape - 1) * strides + 1) // 2
    # an imaginary step makes mgrid interpret it as number of points
    # (end point inclusive)
    axes = tuple(slice(-half_extent[a], half_extent[a], 1j * shape[a])
                 for a in range(3))
    zz, yy, xx = np.mgrid[axes]
    zz += z_shift
    columns = [g.ravel()[:, None] for g in (zz, yy, xx)]
    points = np.hstack(columns).astype(np.float32)
    return points, zz, yy, xx
@staticmethod
[docs] def point_potential(r, margin_scale, size, repulsion=None):
if repulsion is None:
repulsion = 1.0
left = margin_scale * size
x = (r - left)/(size - left)
v = 1.0 - (x**3*(x*(x*6 - 15) + 10)) # soft step function
v = np.minimum(np.maximum(v, 0.0), 1.0)
return v * repulsion
def getbatch(self, prediction, scale_strenght, **get_batch_kwargs):
    """
    Create the next training batch along the current trace.

    Parameters
    ----------
    prediction: [[new_position_c, radius, ]]
    scale_strenght: limits the maximal scale factor for zoom
    get_batch_kwargs: forwarded to ``linked_data.get_newslice`` (except
        ``grid``, ``t_grid_sh`` and ``joint_ratio``); must contain
        ``z_shift`` and, when a new trace is started, ``r_max_scale`` and
        ``tracing_dir_prior_c``

    Returns
    -------
    batch: img, target_img, target_grid, target_node

    Raises
    ------
    RuntimeError: if ``grid`` is requested (target creation is disabled)
    transformations.WarpingOOBError: if the slice is out of bounds
    """
    get_batch_kwargs = dict(get_batch_kwargs)  # copy because we destroy it
    if self.start_new_training:
        self._new_training_trace(**get_batch_kwargs)
        self.start_new_training = False
        scale = 1.0
        self.prev_scale = 1.0
        self.prev_gamma = np.random.rand() * 2 * np.pi
    elif np.allclose(prediction, 0):
        scale = self.prev_scale
        if config.inspection:
            inspection_logger.warning("getbatch with no feedback: either "
                                      "training on same skel or error")
    else:
        prediction = prediction[0]
        new_position_c = prediction[:3]
        radius = prediction[3]  # this is just the predicted val, not the true
        new_position_l, tracing_direc_il = self.trafo.cnn_pred2lab_position(new_position_c)
        new_position_s = new_position_l[::-1]
        self.position_s = new_position_s
        self.position_l = new_position_s[::-1]  # from lab2data (xyz)->(zxy)
        self.direction_il = tracing_direc_il
        scale = self.get_scale_factor(radius, self.prev_scale, scale_strenght)
        self.prev_scale = scale

    grid = get_batch_kwargs.pop('grid', False)
    t_grid_sh = get_batch_kwargs.pop('t_grid_sh', None)
    z_shift = get_batch_kwargs['z_shift']
    get_batch_kwargs.pop('joint_ratio', None)
    try:
        if config.inspection:
            inspection_logger.info("Getslice from position_l %s in "
                                   "direction_il %s, SCALE %.2f"%(np.array_str(
                                       self.position_l, precision=1,
                                       suppress_small=True),
                                       self.direction_il, scale))

        get_batch_kwargs['gamma'] = self.prev_gamma
        data_batch = self.linked_data.get_newslice(self.position_l,
                                                   self.direction_il,
                                                   scale=scale,
                                                   **get_batch_kwargs)
        img, target_img, trafo = data_batch[:3]
        if grid:
            raise RuntimeError("The creation of the grid target must "
                               "be tested for spatial coherence again")
            # NOTE: everything below in this branch is intentionally
            # unreachable until the grid target has been re-validated.
            if not self.cnn_grid:
                self.cnn_grid = self.make_grid(t_grid_sh, z_shift)

            grid_coords_c, zz, yy, xx = self.cnn_grid
            ### TODO might also make repulsion depending on skel_node instead of grid_position. No WHY?
            ### TODO repulsion is not smooth enough
            repulsion = 1.0
            grid_coords = trafo.cnn_coord2lab_coord(grid_coords_c, add_offset_l=True)
            dist, ind, nearest_s = self.get_closest_node(grid_coords[:,::-1])
            radii = self.all_radii[ind]
            target_grid = self.point_potential(dist, 0.1, radii, repulsion)
            target_grid = target_grid.reshape(zz.shape)[None]  # add channel
            if np.allclose(target_grid, 0.0):
                logger.warning("WTF")

            self.debug_shit = [img, target_grid]
            self.debug_shit2 = [nearest_s, radii]
        else:
            target_grid = np.ones((1,1,1,1), dtype=np.float32)

        # Get bio labels/classes
        dist, ind, nearest_s = self.get_closest_node(self.position_s)
        classes = self.all_props[ind]
        # layout: [scaled radius, one-hot class 0 (3 slots), one-hot class 1 (3 slots)]
        target_node = np.zeros(7, dtype=np.float32)
        target_node[classes[0]+1] = 1
        target_node[classes[1]+4] = 1
        target_node[0] = self.all_radii[ind] * scale
        if config.inspection:
            inspection_logger.info("target_node %s, (true r: %.1f)"
                                   %(target_node, self.all_radii[ind]))

        batch = (img, target_img, target_grid, target_node)
        self.trafo = trafo
        return batch

    except transformations.WarpingOOBError:
        if config.inspection:
            inspection_logger.info("OOB in getbatch")
        raise transformations.WarpingOOBError("Batch OOB")
def step_feedback(self, new_position_s, new_direction_is, pred_c,
                  pred_features, cutoff_inner=1.0/3, rise_factor=0.1):
    """
    Evaluate a predicted step and record it in the current trace.

    The loss/gradient follow the same zoned error surface as
    ``get_loss_and_gradient`` (zero inside the inner hull, constant gradient
    magnitude inside the hull, rising outside, where ``lost_track`` is set).
    The position and direction attributes are updated to the new step.

    Returns
    -------
    loss: float32 array with one element
    grad_s: gradient vector
    nearest_s: closest inner hull point
    """
    inner_hull, indices = self.get_hull_points_inner(cutoff_inner, return_indices=True)
    tree = self.get_kdtree(inner_hull, k=1, jobs=1)
    dists, inds, nearest = self.get_knn(tree, new_position_s)
    dist, ind, nearest_s = dists[0], inds[0], nearest[0]

    # we are within inner hull. The maximal distance if within hull is 1.2...
    if dist < 1.5:
        loss = 0.0
        grad_s = np.array([0, 0, 0], dtype=np.float32)
    else:
        loss = dist
        max_dist = self.hull_skel['max_dist'][indices[ind]]  # max dist of closest node
        unit_grad = (new_position_s - nearest_s)  # pointing from nearest to new position
        unit_grad /= np.linalg.norm(unit_grad * self.aniso_scale[0],
                                    axis=0)  # normalise grad
        if max_dist > dist:  # we are in hull but not in inner tube
            grad_s = unit_grad * 1.0
        else:  # we are outside hull
            self.lost_track = True
            if config.inspection:
                inspection_logger.info("Lost track")
            grad_s = unit_grad * (1 + rise_factor * (dist - max_dist))

    self.current_trace.append(new_position_s, coord_cnn=pred_c,
                              grad=grad_s, features=pred_features)
    # Actually the new positions should be set in getbatch, but we need to
    # set them here too because sometimes getbatch might be called without
    # "start_new_training" and with only zeros as prediction
    self.position_s = new_position_s
    self.position_l = new_position_s[::-1]  # from lab to data frame (xyz) -> (zxy)
    self.direction_il = new_direction_is[::-1]  # from lab to data frame (xyz) -> (zxy)
    return np.array([loss,], dtype=np.float32), grad_s, nearest_s
def step_grid_update(self, grid, radius, bio):
    """
    Derive the next position from a predicted grid and record it.

    The highest peak of ``grid[0,0]`` is taken as new position (cnn
    coordinates); without any peak the fixed position [2,0,0] is used. The
    step is appended to the current trace and the position/direction
    attributes are updated.

    Returns
    -------
    new_position_c[None]: the chosen position with a leading axis
    preds_c: the chosen position (or an empty selection if there is no peak)
    scores: the peak values, ascending
    """
    pred_features = np.hstack([radius, bio])
    flat_indices, scores = find_peaks(grid[0,0])
    grid_coords_c = self.cnn_grid[0]
    if len(scores):
        # peaks are sorted ascending, the last one is the strongest
        new_position_c = grid_coords_c[flat_indices[-1]]
        preds_c = new_position_c[None]
    else:
        new_position_c = np.array([2,0,0], dtype=np.float32)
        preds_c = grid_coords_c[flat_indices]

    new_position_l, tracing_direc_il = self.trafo.cnn_pred2lab_position(new_position_c)
    new_position_s = new_position_l[::-1]
    new_direction_is = tracing_direc_il[::-1]
    if config.inspection:
        inspection_logger.info("GridUpdate, node pred %s" % (
            np.array_str(pred_features, precision=2,
                         suppress_small=True),))
        inspection_logger.info(
            "GridUpdate, new_position_c: %s, new_position_l: %s" % (
                new_position_c, np.array_str(new_position_l, precision=1,
                                             suppress_small=True)))
    if config.inspection>1:
        img, grid_t = self.debug_shit
        utils.picklesave([img[0,0], grid_t[0], grid[0,0]], '/tmp/debug_shit_skel_%i' %self.skel_num)

    self.current_trace.append(new_position_s, coord_cnn=new_position_c,
                              features=pred_features)
    # Actually the new positions should be set in getbatch, but we need to
    # set them here too because sometimes getbatch might be called without
    # "start_new_training" and with only zeros as prediction
    self.position_s = new_position_s
    self.position_l = new_position_s[::-1]  # from lab to data frame (xyz) -> (zxy)
    self.direction_il = new_direction_is[::-1]  # from lab to data frame (xyz) -> (zxy)
    return new_position_c[None], preds_c, scores
### Plotting ###
def plot_skel(self, fig=None):
    """
    Plot all nodes (points), bones (tubes) and joints with mayavi.

    z coordinates are multiplied by the z anisotropy. A new figure is
    created if ``fig`` is None; the figure is returned.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    z_scale = self.aniso_scale[0,2]
    nodes = self.all_nodes
    mlab.points3d(nodes[:,0], nodes[:,1], nodes[:,2]*z_scale,
                  scale_factor=0.8, color=(1,0,0), figure=fig)
    for bone in self.bones.values():
        mlab.plot3d(bone[:,0], bone[:,1], bone[:,2]*z_scale,
                    tube_radius=0.4, color=(0.3,0.3,0.3), figure=fig)
    self._plot_joints(fig=fig)
    return fig
def plot_debug_traces(self, grads=True, fig=None):
    """
    Plot the stored debug traces as tubes and, if ``grads`` is true, the
    negated debug gradients as arrows along them.

    z components are multiplied by the z anisotropy. A new figure is created
    if ``fig`` is None; the figure is returned.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    z_scale = self.aniso_scale[0, 2]
    traces = np.array(self.debug_traces)
    for trace in traces:
        mlab.plot3d(trace[:, 0], trace[:, 1], trace[:, 2] * z_scale,
                    tube_radius=0.2, color=(0.3, 0.3, 0.3), figure=fig)
    if grads:
        grad_arrays = np.array(self.debug_grads)
        for grad, trace in zip(grad_arrays, traces):
            mlab.quiver3d(trace[:, 0], trace[:, 1], trace[:, 2] * z_scale,
                          -grad[:, 0], -grad[:, 1], -grad[:, 2] * z_scale,
                          figure=fig, color=(0,0.6,0.2), scale_factor=3)
    return fig
def plot_radii(self, fig=None):
    """
    Plot a transparent sphere at every node, scaled by the node radius.

    A new figure is created if ``fig`` is None; the figure is returned.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    nodes = self.all_nodes
    z_iso = nodes[:,2] * self.aniso_scale[0,2]
    mlab.points3d(nodes[:,0], nodes[:,1], z_iso, self.all_radii,
                  scale_mode='scalar', scale_factor=1,
                  color=(0,0.5,0.5), mode='sphere', opacity=0.1, figure=fig)
    return fig
def _plot_joints(self, fig=None):
    """
    Plot the joints as large yellow points.

    A new figure is created if ``fig`` is None; the figure is returned.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    joints = self.joints
    z_iso = joints[:,2] * self.aniso_scale[0,2]
    mlab.points3d(joints[:,0], joints[:,1], z_iso,
                  scale_factor=3, color=(1,1,0), figure=fig)
    return fig
### Hull methods ###
def calc_max_dist_to_skels(self):
    """
    For every hull point, cast a ray away from its nearest skeleton node
    through the hull (see ``ray_cast``).

    Requires ``hull_points`` and ``hull_skel['direc']`` / ``['dist']``.

    Returns
    -------
    max_dist: per hull point, the distance accumulated by the ray cast
    rel_dist: ``hull_skel['dist'] / max_dist``
    """
    hull = self.hull_points  # (n, 3)
    direc = self.hull_skel['direc']  # (n, 3)
    dist = self.hull_skel['dist']  # (n) true distances
    max_dist = np.zeros(len(hull), dtype=np.float32)
    # This ray has unit magnitude in the true metric
    ray_steps = -direc/(np.linalg.norm(direc * self.aniso_scale, axis=1)[:,None]+1e-5)
    # create dense cube and insert hull
    sh = np.max(hull,0) + 1
    off = np.min(hull,0)
    sh -= off
    # np.bool_ instead of the ``np.bool`` alias removed in NumPy 1.24
    hull_cube = np.zeros(sh, dtype=np.bool_)
    insert(hull_cube, hull, True, off)
    # cast rays through dense cube
    ray_cast(max_dist, hull, dist, ray_steps, hull_cube, off)
    # in this case the magnitude of the direc vector is 0 anyway
    max_dist[np.any(~np.isfinite(ray_steps), axis=1)] = 1.0
    rel_dist = dist / max_dist
    return max_dist, rel_dist
def map_hull(self, hull_points):
    """
    Attach hull points to the skeleton and relate them to nodes/branches.

    Distances already take the anisotropy in z into account (i.e. they are
    true distances), but all coordinates for hulls and vectors are still
    pixel coordinates.

    Fills ``hull_points`` (int16), ``hull_skel``, ``hull_branch`` and
    ``kdt_hull``.

    Raises
    ------
    ValueError: if a distance to the skeleton or the branches is not finite
    """
    self.hull_points = hull_points.astype(np.int16)
    query = hull_points.astype(np.float32)
    n_hull = len(query)

    skel_tree = self.get_kdtree(self.all_nodes)
    dist_skel, ind_skel, coord_skel = self.get_knn(skel_tree, query)
    self.hull_skel.update({'dist': dist_skel,
                           'ind': ind_skel,
                           'direc': coord_skel - query})  # NNs - Queries
    max_dist, rel_dist = self.calc_max_dist_to_skels()
    self.hull_skel.update({'max_dist': max_dist, 'rel_dist': rel_dist})

    if len(self.branches):
        branch_tree = self.get_kdtree(self.branches)
        dist_branch, ind_branch, coord_branch = self.get_knn(branch_tree, query)
        self.hull_branch.update({'dist': dist_branch,
                                 'ind': ind_branch,
                                 'direc': coord_branch - query})
    else:  # no branch points: use neutral placeholders
        self.hull_branch.update({'dist': np.zeros(n_hull, dtype=np.float32),
                                 'ind': None,
                                 'direc': np.zeros((n_hull, 3), dtype=np.float32)})

    all_finite = np.all(np.isfinite(dist_skel)) and \
                 np.all(np.isfinite(self.hull_branch['dist']))
    if not all_finite:
        raise ValueError("InfiniteValue")

    self.kdt_hull = self.get_kdtree(self.hull_points, k=1, jobs=1)  # store for later use
@utils.cache()
def get_hull_points_inner(self, cutoff=1.0/3, return_indices=False):
    """
    Select the hull points whose relative distance to the skeleton is
    below ``cutoff``.

    :param cutoff: threshold on ``self.hull_skel['rel_dist']``
    :param return_indices: if True, also return the selected point indices
    :return: selected points, or ``(points, indices)``
    """
    inner = self.hull_skel['rel_dist'] < cutoff
    selected = self.hull_points[inner]
    if not return_indices:
        return selected
    return selected, inner.nonzero()[0]
@utils.cache()
def get_hull_branch_direc_cutoff(self, cutoff=25, normalise=False):
    """
    Return hull-to-branch direction vectors, zeroed for points whose branch
    distance is not below ``cutoff``.

    :param cutoff: distance threshold on ``self.hull_branch['dist']``
    :param normalise: if True, divide every vector by its branch distance
        (plus a small epsilon to avoid division by zero)
    :return: array with the same shape as ``self.hull_branch['direc']``
    """
    mask = self.hull_branch['dist'] < cutoff
    # 'direc' holds one 3-vector per hull point while the mask is per point,
    # so the mask needs a trailing axis to broadcast over the components.
    ret = self.hull_branch['direc'] * mask[:, None]
    if normalise:
        ret /= (self.hull_branch['dist'][:, None] + 1e-5)
    return ret
@utils.cache()
def get_hull_branch_dist_cutoff(self, cutoff=25, normalise=True):
    """
    Return hull-to-branch distances, zeroed where they are not below
    ``cutoff``.

    :param cutoff: distance threshold on ``self.hull_branch['dist']``
    :param normalise: if True, return a boolean array marking the entries
        that are still positive after the cutoff
    """
    branch_dist = self.hull_branch['dist']
    kept = branch_dist * (branch_dist < cutoff)
    if normalise:
        kept = (kept > 0)
    return kept
@utils.cache()
def get_hull_skel_direc_rel(self):
    """
    Return the hull-to-skeleton direction vectors divided by the per-point
    ``max_dist`` stored in ``self.hull_skel``.
    """
    direc = self.hull_skel['direc']
    per_point_scale = self.hull_skel['max_dist'][:, None]
    return direc / per_point_scale
def plot_hull(self, fig=None):
    """
    Draw all hull points as translucent white cubes (z is scaled by the
    anisotropy factor) and return the mayavi figure.

    :param fig: existing figure to draw into; a new one is created if None
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    pts = self.hull_points
    z_scale = self.aniso_scale[0,2]
    mlab.points3d(pts[:,0], pts[:,1], pts[:,2]*z_scale,
                  scale_factor=1, color=(1,1,1), mode='cube',
                  opacity=0.1, figure=fig)
    return fig
def plot_hull_inner(self, cutoff, fig=None):
    """
    Draw the hull points returned by ``get_hull_points_inner(cutoff)`` as
    translucent cubes and return the mayavi figure.

    :param cutoff: forwarded to ``get_hull_points_inner``
    :param fig: existing figure to draw into; a new one is created if None
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    pts = self.get_hull_points_inner(cutoff)
    z_scale = self.aniso_scale[0,2]
    mlab.points3d(pts[:,0], pts[:,1], pts[:,2]*z_scale,
                  scale_factor=1, color=(0.8,0.8,1), mode='cube',
                  opacity=0.1, figure=fig)
    return fig
def plot_vec(self, substep=15, dict_name='skel', key='direc', vec=None, fig=None):
    """
    Draw a quiver plot of per-hull-point vectors, anchored at the hull points.

    :param substep: only every ``substep``-th point/vector is drawn
    :param dict_name: ``'skel'`` selects ``self.hull_skel``, anything else
        selects ``self.hull_branch`` (only used if ``vec`` is None)
    :param key: entry of the selected dict to plot (only used if ``vec`` is None)
    :param vec: explicit vector array to plot instead of a stored one
    :param fig: existing figure to draw into; a new one is created if None
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    z_scale = self.aniso_scale[0,2]
    if vec is None:
        source = self.hull_skel if dict_name=='skel' else self.hull_branch
        vec = source[key]

    anchors = self.hull_points
    x = anchors[:,0][::substep]
    y = anchors[:,1][::substep]
    z = (anchors[:,2]*z_scale)[::substep]
    u = vec[:,0][::substep]
    v = vec[:,1][::substep]
    w = (vec[:,2]*z_scale)[::substep]
    mlab.quiver3d(x,y,z, u,v,w, figure=fig)
    return fig
class Trace(object):
    """
    Unless otherwise stated, all coordinates are in the skeleton system (xyz)
    with an anisotropic z-axis and all distances are in pixels (conversion
    to mu: 1/100).
    """
    def __init__(self, linked_skel=None, aniso_scale=2,max_cutoff=200,
                 uturn_detection_k=40, uturn_detection_thresh=0.45,
                 uturn_detection_hold=10, feature_count=7):
        """
        :param linked_skel: skeleton the trace is compared against (may be None)
        :param aniso_scale: scale factor applied to the z-axis for distances
        :param max_cutoff: skeleton distance above which the trace counts as lost
        :param uturn_detection_k: number of neighbours used for u-turn detection
        :param uturn_detection_thresh: threshold on the normalised self-distance
        :param uturn_detection_hold: number of consecutive steps that must be
            below the threshold to register a u-turn
        :param feature_count: number of feature values stored per trace point
        """
        self.aniso_scale = np.array([[1, 1, aniso_scale]], dtype=np.float32)
        self.skel = linked_skel
        # Both flags start as False and are replaced by a
        # (length, runlength) tuple when the event is first registered.
        self.lost_track = False
        self.uturn_occurred = False
        self.coords = utils.AccumulationArray(right_shape=3, n_init=500)
        self.seg_length = utils.AccumulationArray(n_init=500)
        self.runlengths = utils.AccumulationArray(n_init=500)
        self.dist_self = utils.AccumulationArray(right_shape=2, n_init=500)
        self.dist_skel = utils.AccumulationArray(n_init=500)
        # ``np.bool`` was only an alias of the builtin and has been removed
        # from NumPy (1.24); the builtin ``bool`` is the equivalent dtype.
        self.uturn_mask = utils.AccumulationArray(n_init=500, dtype=bool)
        self.coords_cnn = utils.AccumulationArray(right_shape=3, n_init=500)
        self.grads = utils.AccumulationArray(right_shape=3, n_init=500)
        self.features = utils.AccumulationArray(right_shape=feature_count, n_init=500)
        self.max_cutoff = max_cutoff
        self.uturn_detection_k = uturn_detection_k
        self.uturn_detection_thresh = uturn_detection_thresh
        self.uturn_detection_hold = uturn_detection_hold
        self.kdt = utils.DynamicKDT(k=uturn_detection_k, n_jobs=1,
                                    aniso_scale=self.aniso_scale)
        self.root = 0
        self.comment = ""
def new_reverted_trace(self):
    """
    Return a new ``Trace`` holding the same points in reversed order.

    All per-point arrays are reversed, the run lengths are re-based so they
    count from the new start, the kd-tree is rebuilt, the root is set to the
    last index and `` R`` is appended to the comment.
    """
    flipped = Trace(self.skel, self.aniso_scale[0,2],
                    self.max_cutoff, self.uturn_detection_k,
                    self.uturn_detection_thresh, self.uturn_detection_hold,
                    self.features.data.shape[1:])

    # Arrays that only need their order flipped
    for attr in ('coords', 'seg_length', 'dist_self', 'dist_skel',
                 'uturn_mask', 'coords_cnn', 'grads', 'features'):
        source = getattr(self, attr)
        setattr(flipped, attr, utils.AccumulationArray(data=source[::-1]))
    # Run lengths are measured from the (new) first point
    flipped.runlengths = utils.AccumulationArray(
        data=self.runlengths[-1]-self.runlengths[::-1])

    k = self.uturn_detection_k
    if len(flipped) > k:
        flipped.kdt = utils.DynamicKDT(points=flipped.coords.data,
                                       k=k,
                                       n_jobs=1,
                                       aniso_scale=self.aniso_scale)
    else:
        # Too few points for a bulk build: insert them one by one
        tree = utils.DynamicKDT(k=k, n_jobs=1,
                                aniso_scale=self.aniso_scale)
        for point in flipped.coords.data:
            tree.append(point)
        flipped.kdt = tree

    flipped.root = len(self)-1
    # Instances without a ``comment`` attribute get an empty one first
    if not hasattr(self, 'comment'):
        self.comment = ""
    flipped.comment = self.comment+ " R"
    return flipped
def new_cut_trace(self, start, stop):
    """
    Return a new ``Trace`` containing only the points ``start:stop``.

    Run lengths are shifted so the cut trace starts at zero. The root index
    is carried over if it falls inside the cut, otherwise it becomes None.
    ``C<start>-<stop>`` is appended to the comment.

    :param start: first index to keep
    :param stop: index one past the last point to keep
    """
    piece = Trace(self.skel, self.aniso_scale[0,2],
                  self.max_cutoff, self.uturn_detection_k,
                  self.uturn_detection_thresh, self.uturn_detection_hold,
                  self.features.data.shape[1:])

    # Arrays that are sliced without further adjustment
    for attr in ('coords', 'seg_length', 'dist_self', 'dist_skel',
                 'uturn_mask', 'coords_cnn', 'grads', 'features'):
        source = getattr(self, attr)
        setattr(piece, attr, utils.AccumulationArray(data=source[start:stop]))
    piece.runlengths = utils.AccumulationArray(
        data=self.runlengths[start:stop]-self.runlengths[start])

    k = self.uturn_detection_k
    if len(piece) > k:
        piece.kdt = utils.DynamicKDT(points=piece.coords.data,
                                     k=k,
                                     n_jobs=1,
                                     aniso_scale=self.aniso_scale)
    else:
        # Too few points for a bulk build: insert them one by one
        tree = utils.DynamicKDT(k=k, n_jobs=1,
                                aniso_scale=self.aniso_scale)
        for point in piece.coords.data:
            tree.append(point)
        piece.kdt = tree

    shifted_root = self.root - start
    if 0 <= shifted_root < len(piece):
        piece.root = shifted_root
    else:
        piece.root = None

    # Instances without a ``comment`` attribute get an empty one first
    if not hasattr(self, 'comment'):
        self.comment = ""
    piece.comment = self.comment + "C%i-%i"%(start, stop)
    return piece
def __len__(self):
    """Return the number of stored trace points (length of ``self.coords``)."""
    return len(self.coords)
def save(self, fname):
    """Store this trace under ``fname`` via ``utils.picklesave``."""
    utils.picklesave(self, fname)
def save_to_kzip(self, fname):
    """Export this trace with ``trace_to_kzip`` (which appends ``.k.zip`` to ``fname``)."""
    trace_to_kzip(self, fname)
def add_offset(self, off):
    """
    Shift all stored coordinates by ``off`` and rebuild the kd-tree from the
    shifted coordinates.

    :param off: offset, promoted to at least 2 dimensions before it is
        handed to ``self.coords.add_offset``
    """
    shift = np.atleast_2d(off)
    self.coords.add_offset(shift)

    k = self.uturn_detection_k
    if len(self) > k:
        self.kdt = utils.DynamicKDT(points=self.coords.data,
                                    k=k, n_jobs=1,
                                    aniso_scale=self.aniso_scale)
        return

    # Too few points for a bulk build: insert them one by one
    rebuilt = utils.DynamicKDT(k=k, n_jobs=1,
                               aniso_scale=self.aniso_scale)
    for point in self.coords.data:
        rebuilt.append(point)
    self.kdt = rebuilt
def append(self, coord, coord_cnn=None, grad=None, features=None):
    """
    Append one point to the trace and update all derived statistics.

    Updates, in this order: coordinates, segment length, accumulated run
    length, self-distance (for u-turn detection), the kd-tree, distance to
    the linked skeleton (if any), optional per-point data, the u-turn mask
    and the ``uturn_occurred`` / ``lost_track`` flags.

    :param coord: new trace coordinate
    :param coord_cnn: optional coordinate stored in ``self.coords_cnn``
    :param grad: optional gradient stored in ``self.grads``
    :param features: optional feature vector stored in ``self.features``
    """
    self.coords.append(coord)
    if len(self)>1:
        # Anisotropy-corrected distance to the previous point
        diff = np.linalg.norm((coord - self.coords[-2]) * self.aniso_scale[0])
    else:
        diff = 5 # just guess (first point has no predecessor)
    self.seg_length.append(diff)
    self.runlengths.append(self.runlength)
    if len(self) > self.uturn_detection_k+1:
        # Mean distance to the k nearest previously stored trace points;
        # the new point is added to the tree only further below.
        distances, indices, coordinates = self.kdt.get_knn(coord, k=self.uturn_detection_k)
        dist = distances.mean()
    else:
        # Not enough history yet: use the normalisation value itself, so the
        # normalised distance below is 1
        dist = self.seg_length.ema * float(self.uturn_detection_k + 1) / 2
    normalisation = self.seg_length.ema * float(self.uturn_detection_k + 1) / 2
    self.dist_self.append([dist, dist/normalisation])
    self.kdt.append(coord)
    if self.skel:
        dist, index, node = self.skel.get_closest_node(coord)
        self.dist_skel.append(dist)
    if grad is not None:
        self.grads.append(grad)
    if features is not None:
        self.features.append(features)
    if coord_cnn is not None:
        self.coords_cnn.append(coord_cnn)
    # Check for criteria: a u-turn is flagged when the normalised
    # self-distance stayed below the threshold for the last `hold` steps
    last_dist = self.dist_self[-self.uturn_detection_hold:, 1]
    uturn = np.all(last_dist < self.uturn_detection_thresh)
    self.uturn_mask.append(uturn)
    if not self.uturn_occurred and uturn: # register the first u-turn
        self.uturn_occurred = (len(self), self.runlength)
    if not self.lost_track:
        # NOTE(review): this is evaluated even when no skeleton is linked and
        # ``dist_skel`` is empty - confirm what ``max()`` returns in that case
        lost = self.dist_skel.max() > self.max_cutoff
        if lost:
            self.lost_track = (len(self), self.runlength)
def append_serial(self, *args):
    """
    Call ``self.append`` once per step, taking the i-th element of every
    positional sequence as the positional arguments of the i-th call.
    """
    add_point = self.append
    for per_step_args in zip(*args):
        add_point(*per_step_args)
@property
def avg_seg_length(self):
    """Mean of the recorded segment lengths."""
    return self.seg_length.mean()
@property
def runlength(self):
    """Sum of all recorded segment lengths (total path length so far)."""
    return self.seg_length.sum()
@property
def avg_dist_skel(self):
    """Mean of the recorded distances to the linked skeleton."""
    return self.dist_skel.mean()
@property
def max_dist_skel(self):
    """Maximum of the recorded distances to the linked skeleton."""
    return self.dist_skel.max()
@property
def avg_dist_self(self):
    """Mean of ``self.dist_self`` as computed by its ``mean()`` method."""
    return self.dist_self.mean()
@property
def min_dist_self(self):
    """
    Entry 0 of the ``dist_self`` minimum. Rows of ``dist_self`` are
    ``[dist, dist/normalisation]``, so this is presumably the smallest raw
    self-distance (assumes ``min()`` reduces per column - TODO confirm).
    """
    return self.dist_self.min()[0]
@property
def min_normed_dist_self(self):
    """
    Entry 1 of the ``dist_self`` minimum, presumably the smallest normalised
    self-distance (assumes ``min()`` reduces per column - TODO confirm).
    """
    return self.dist_self.min()[1]
def tortuosity(self, start=None, end=None):
    """
    Ratio of path length to straight-line distance between two trace points.

    :param start: index of the first point (default: 0)
    :param end: index one past the last point (default: ``len(self)``)
    :return: accumulated run length between the two points divided by their
        anisotropy-corrected Euclidean distance
    """
    first = 0 if start is None else start
    last = (len(self) if end is None else end) - 1
    path_length = self.runlengths[last] - self.runlengths[first]
    straight = (self.coords[last] - self.coords[first]) * self.aniso_scale[0]
    return path_length / np.linalg.norm(straight)
def plot(self, grads=True, skel=True, rand_color=False, fig=None):
    """
    Plot the trace as a tube with point markers, optionally together with
    the linked skeleton and the negated stored gradients.

    :param grads: if True and gradients were stored, draw them as arrows
    :param skel: if True and a skeleton is linked, draw it first
    :param rand_color: use one random colour for line and points
    :param fig: existing figure to draw into; a new one is created if None
    :return: the mayavi figure
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400))
    if skel and self.skel:
        fig = self.skel.plot_skel(fig=fig)

    z_scale = self.aniso_scale[0, 2]
    x = self.coords[:, 0]
    y = self.coords[:, 1]
    z = self.coords[:, 2] * z_scale

    if rand_color:
        line_c = tuple(np.random.rand(3))
        point_c = line_c
    else:
        line_c = (0, 0, 0.7)
        point_c = (0.6, 0.7, 0.9)
    mlab.plot3d(x, y, z, tube_radius=0.2, color=line_c, figure=fig)
    mlab.points3d(x, y, z, scale_factor=0.8, color=point_c, figure=fig)

    if grads and self.grads.length:
        # Arrows point along the negative gradient
        gx = -self.grads[:, 0]
        gy = -self.grads[:, 1]
        gz = -self.grads[:, 2] * z_scale
        mlab.quiver3d(x, y, z, gx, gy, gz, figure=fig,
                      color=(0, 0.6, 0.2), scale_factor=3)
    return fig
def split_uturns(self, return_accum_pathlength=False, print_stat=False):
    """
    Split the trace at the recorded u-turns into separate traces.

    The positions where ``self.uturn_mask`` changes are used as cut points
    (shifted back by ``uturn_detection_hold``). Every second interval between
    cut points is kept as a new trace; the intervals in between (the u-turns)
    are dropped.

    :param return_accum_pathlength: if True, additionally return the run
        lengths and skeleton distances of the kept pieces, concatenated
    :param print_stat: if True, print the cut indices
    :return: list of new traces, or ``(traces, pathlengths, dist_skel)``
    """
    # Indices at which the u-turn mask switches its value
    transitions = np.diff(self.uturn_mask, axis=0, n=1)
    transitions = np.nonzero(transitions)[0]
    transitions[0::2] -= self.uturn_detection_hold # if add: end segment closer to uturn
    transitions[1::2] -= self.uturn_detection_hold # if subtract: start new segment closer to uturn
    # Clip to the valid index range and add the trace start/end as borders
    transitions = np.minimum(np.maximum(0, transitions), len(self))
    transitions = np.hstack((0, transitions, len(self)))
    new_traces = []
    accum_pathlenghts = []
    accum_dist_skel = []
    accum_runlength = 0
    # Step of 2: (transitions[i], transitions[i+1]) is a kept interval, the
    # interval after it is skipped
    for i in range(0, len(transitions)-1, 2):
        # NOTE(review): feature_count is not forwarded, so the new trace uses
        # the constructor default
        new = self.__class__(self.skel, self.aniso_scale[0, 2], self.max_cutoff,
                             self.uturn_detection_k, self.uturn_detection_thresh,
                             self.uturn_detection_hold)
        start, stop = transitions[i], transitions[i+1]
        if print_stat:
            print("cutting between %i and %i " % (start, stop))
        if start<stop: # some transitions are too short
            coords = self.coords[start:stop]
            new.append_serial(coords)
            new_traces.append(new)
            # accumulate pathlenghts and dist to skel over splits
            runlengths = self.runlengths[start:stop]
            runlengths = runlengths + accum_runlength - runlengths[0] # shift
            accum_pathlenghts.append(runlengths)
            accum_runlength = runlengths[-1]
            if self.skel:
                dist_skel = self.dist_skel[start:stop].copy()
                if i > 0: # this makes the eval stop if the trace deviated from the
                    # skeleton too much during the uturn
                    max_dist_in_uturn = np.max(self.dist_skel[transitions[i-1]:start])
                    dist_skel[0] = np.maximum(dist_skel[0], max_dist_in_uturn)
                accum_dist_skel.append(dist_skel)
            else:
                accum_dist_skel.append([])
    if return_accum_pathlength:
        return new_traces, np.hstack(accum_pathlenghts), np.hstack(accum_dist_skel)
    else:
        return new_traces
def normalised_min_dist(tr, point):
    """
    Distance from ``point`` to the nearest node of trace ``tr``, both raw and
    divided by feature 0 of that node (used as its radius).

    :param tr: trace providing ``kdt`` and ``features``
    :param point: query coordinate(s) passed to ``tr.kdt.get_knn``
    :return: ``(distance, distance / radius)``
    """
    nn_dist, nn_ind, _ = tr.kdt.get_knn(point, k=1)
    nn_radius = tr.features[nn_ind, 0]
    relative = nn_dist / nn_radius
    return nn_dist, relative
def simple_stats(a):
    """
    Stack mean, standard deviation, minimum and maximum of ``a`` (each taken
    along axis 0) into one array, in that order.
    """
    reducers = (np.mean, np.std, np.min, np.max)
    return np.array([reduce_fn(a, axis=0) for reduce_fn in reducers])
def radius_hist(r):
    """
    Density histogram of the radii ``r`` over the fixed bin edges
    0, 8, 14, 23, 35, 50, 80, 200.

    :return: array with the 7 density values (bin edges are not returned)
    """
    edges = np.array([0,8,14,23,35,50,80,200])
    return np.histogram(r, bins=edges, density=True)[0]
def get_merge_features(main_tr, main_node, sub_tr, sub_node, end_match):
    """
    Build feature vectors describing a candidate connection between a node
    of a main trace and a node of a sub trace.

    :param main_tr: main trace
    :param main_node: node index in ``main_tr``
    :param sub_tr: sub trace
    :param sub_node: node index in ``sub_tr``
    :param end_match: if True, the sub-trace windows are taken before
        ``sub_node``, otherwise starting at ``sub_node``
    :return: ``(main_features, sub_features, joint_features)``, each a flat
        array
    """
    # Windows of +-5 and +-25 nodes around the main node (clipped at 0)
    m_slice_small = slice(max(0,main_node-5), main_node+5)
    m_slice_large = slice(max(0,main_node-25), main_node+25)
    m_points = main_tr.coords[m_slice_small] * main_tr.aniso_scale
    # SVD of the centred points: singular values and principal directions
    uu, pc_m, pc_dir_m = np.linalg.svd(m_points-m_points.mean(0))
    m_feat_small = simple_stats(main_tr.features[m_slice_small])
    m_feat_large = simple_stats(main_tr.features[m_slice_large])
    m_radius_hist= radius_hist(main_tr.features[m_slice_large, 0])
    m_tortuosity = main_tr.tortuosity()
    main_features= np.hstack([pc_m, pc_dir_m.ravel(), m_feat_small.ravel(),
                              m_feat_large.ravel(), m_radius_hist, m_tortuosity])
    # Windows of 10 and 50 nodes on the sub trace, before or after sub_node
    if end_match:
        s_slice_small = slice(max(0,sub_node-10), sub_node)
        s_slice_large = slice(max(0,sub_node-50), sub_node)
    else:
        s_slice_small = slice(sub_node, sub_node+10)
        s_slice_large = slice(sub_node, sub_node+50)
    s_points = sub_tr.coords[s_slice_small] * sub_tr.aniso_scale
    # NOTE(review): this check does nothing; an empty window is passed on to
    # the SVD below unchanged
    if len(s_points)==0:
        pass
    uu, pc_s, pc_dir_s = np.linalg.svd(s_points-s_points.mean(0))
    s_feat_small = simple_stats(sub_tr.features[s_slice_small])
    s_feat_large = simple_stats(sub_tr.features[s_slice_large])
    s_radius_hist= radius_hist(sub_tr.features[s_slice_large, 0])
    s_tortuosity = sub_tr.tortuosity()
    sub_features = np.hstack([pc_s, pc_dir_s.ravel(), s_feat_small.ravel(),
                              s_feat_large.ravel(), s_radius_hist, s_tortuosity])
    # Joint features: node distance (raw and relative to the radii, i.e.
    # feature 0) and the absolute dot products of the principal directions
    dist = np.linalg.norm((main_tr.coords[main_node]-sub_tr.coords[sub_node])*sub_tr.aniso_scale)
    r_m = main_tr.features[main_node, 0]
    r_s = sub_tr.features[sub_node, 0]
    pc_dir_similarity = np.abs(np.dot(pc_dir_m, pc_dir_s.T))
    joint_features = np.hstack([dist, dist/r_m, dist/r_s,
                                2*dist/(r_m+r_s), pc_dir_similarity.ravel()])
    return main_features, sub_features, joint_features
def split_tree_components(tracetree, cut=False):
    """
    Split a ``TraceTree`` into one tree per connected component.

    :param tracetree: tree whose ``num_components`` / ``tr_i2comp_i`` have
        been computed
    :param cut: if True, traces with registered cuts are replaced by their
        cut version in the new trees
    :return: list of trees; the input itself if there is only one component
    """
    n_components = tracetree.num_components
    if n_components==1:
        return [tracetree,]

    members = [[] for _ in range(n_components)]
    for tr_i in tracetree.traces:
        component = tracetree.tr_i2comp_i[tr_i]
        tr = tracetree.traces[tr_i]
        cuts = tracetree.trace_cuts.get(tr_i, None)
        if cuts and cut:
            tr = tr.new_cut_trace(*cuts)
        members[component].append(tr)

    return [TraceTree(group,
                      tracetree.spine_thresh,
                      tracetree.endpoint_thresh)
            for group in members]
class TraceTree(object):
    def __init__(self, traces, spine_thresh=1.5, endpoint_thresh=0.8):
        """
        :param traces: dict or sequence of traces
            A sequence is converted to a dict keyed by position.
        :param spine_thresh: float
            How large the maximal relative distance needs to be for a loop
            to be retained as a spine branch
        :param endpoint_thresh: float
            Threshold of relative distance between endpoint and other trace
            to count as a connection
        """
        # Rename trace keys to smaller contiguous numbers
        if not isinstance(traces, dict):
            traces = {idx: tr for idx, tr in zip(range(len(traces)), traces)}
        self.traces = traces

        # Graph bookkeeping
        self.trace_cuts = dict()
        self.pruned_traces = []
        self.edge_candidates = dict()
        self.edges = []
        self.tr_i2comp_i = None
        self.num_components = 1

        self.aniso = np.array([[1,1,2]])
        self.spine_thresh = spine_thresh
        self.endpoint_thresh = endpoint_thresh

        # Filled by build_joined_features()
        self.joined_kdt = None
        self.joined_coords = None
        self.joined_radii = None
def build_joined_features(self):
    """
    Concatenate coordinates and radii (feature 0) of all traces and build
    one kd-tree over the joined coordinates.

    Results are stored in ``joined_coords``, ``joined_radii`` and
    ``joined_kdt``.
    """
    members = list(self.traces.values())
    self.joined_coords = np.vstack([tr.coords for tr in members])
    self.joined_radii = np.hstack([tr.features[:,0] for tr in members])
    self.joined_kdt = utils.DynamicKDT(self.joined_coords, n_jobs=-1,
                                       aniso_scale=[1, 1, 2], k=1)
def cut_traces_inplace(self):
    """
    Replace every trace that has a registered cut in ``self.trace_cuts`` by
    its cut version (``new_cut_trace(*cuts)``); other traces are untouched.
    """
    for tr_i in list(self.traces):
        cuts = self.trace_cuts.get(tr_i, None)
        if not cuts:
            continue
        self.traces[tr_i] = self.traces[tr_i].new_cut_trace(*cuts)
def to_kzip(self, fname, save_loops=False, save_edge_candiates=False,
            add_edges=False, save_edges=False):
    """
    Export the tree as knossos ``.k.zip`` skeleton files.

    One file ``<fname>-c<i>.k.zip`` is written per component, containing the
    (cut) traces of that component. Optional extra files:

    * ``edges-<name>.k.zip`` (``save_edges``): one node pair per MST edge
    * ``candidates-<name>.k.zip`` (``save_edge_candiates``): one commented
      node pair per edge candidate
    * ``<fname>-loops.k.zip`` (``save_loops``): the pruned traces

    The ``edges-`` and ``candidates-`` files are placed in the directory of
    ``fname``; if ``fname`` has no directory part they are written relative
    to the current directory.

    :param fname: output path prefix (``~`` is expanded)
    :param save_loops: also export ``self.pruned_traces``
    :param save_edge_candiates: also export ``self.edge_candidates``
    :param add_edges: additionally connect the exported traces at the edge
        nodes (only has an effect together with ``save_edges``)
    :param save_edges: also export ``self.edges``
    """
    fname = os.path.expanduser(fname)
    fpath, comment_name = os.path.split(fname)

    def new_annotation(comment):
        # Annotation with the fixed scaling used for all exports here
        anno = knossos_skeleton.SkeletonAnnotation()
        anno.scaling = (9.0, 9.0, 20.0)
        anno.setComment(comment)
        return anno

    def add_trace_node(anno, tr_key, node_i, comment=None):
        # Create a node at the rounded coordinate of one trace point
        node = knossos_skeleton.SkeletonNode()
        x, y, z = np.round(self.traces[tr_key].coords[node_i]).astype(np.int16)
        node.from_scratch(anno, x, y, z)
        if comment is not None:
            node.setComment(comment)
        anno.addNode(node)
        return node

    skel_objs = []
    component_annos = []
    for c in range(self.num_components):
        skel_obj = knossos_skeleton.Skeleton()
        anno_ = new_annotation(comment_name+"-c%i"%c)
        skel_obj.add_annotation(anno_)
        skel_objs.append(skel_obj)
        component_annos.append(anno_)

    # Save all cut traces to own anno-obj of their component
    node_mappings = dict()
    for tr_i in self.traces:
        if self.tr_i2comp_i is not None:
            c = self.tr_i2comp_i[tr_i]
        else:
            c = 0
        tr = self.traces[tr_i]
        cuts = self.trace_cuts.get(tr_i, None)
        if cuts:
            tr = tr.new_cut_trace(*cuts)
        _, node_mapping = trace_to_anno(tr, fname, component_annos[c])
        node_mappings[tr_i] = node_mapping

    # Save all edges (between cut points) to a edge annotation
    if len(self.edges) and save_edges:
        edge_anno = new_annotation("Edges-"+comment_name)
        skel_obj_edges = knossos_skeleton.Skeleton()
        skel_obj_edges.add_annotation(edge_anno)
        for e in self.edges:
            try:
                main, sub = e
                main_node_i, sub_node_i = self.edge_candidates[tuple(e)][1]
            except KeyError: # MST might turn around edge order
                main_node_i, sub_node_i = self.edge_candidates[tuple(e)[::-1]][1]
                main, sub = e[::-1]
            main_node = add_trace_node(edge_anno, main, main_node_i)
            sub_node = add_trace_node(edge_anno, sub, sub_node_i)
            edge_anno.addEdge(main_node, sub_node)
            if add_edges:
                # Node indices refer to uncut traces; shift by the cut start
                main_cut = self.trace_cuts.get(main, [0, None])[0]
                sub_cut = self.trace_cuts.get(sub, [0, None])[0]
                main_i = main_node_i - main_cut
                sub_i = sub_node_i - sub_cut
                try:
                    n_main = node_mappings[main][main_i]
                    n_sub = node_mappings[sub][sub_i]
                    n_main.annotation.addEdge(n_main, n_sub)
                except Exception:
                    # Best effort: keep exporting, but do not hide the failure
                    logger.warning("Could not add edge between trace %s "
                                   "(node %s) and trace %s (node %s)",
                                   main, main_i, sub, sub_i, exc_info=True)
        # os.path.join keeps the file next to fname; plain concatenation
        # with "/" would point to the filesystem root for an empty fpath
        outfile = os.path.join(fpath, "edges-" + comment_name + '.k.zip')
        skel_obj_edges.to_kzip(outfile)

    # As Node==Edge==Node in one skeleton Tree (for making GT in knossos)
    if save_edge_candiates and self.edge_candidates:
        edge_candiate_anno = new_annotation(comment_name+"-Edge-Candiates")
        skel_obj_candidates = knossos_skeleton.Skeleton()
        skel_obj_candidates.add_annotation(edge_candiate_anno)
        for e in self.edge_candidates:
            main, sub = e
            main_node_i, sub_node_i = self.edge_candidates[e][1]
            label = comment_name+"-M%i_%i-S%i_%i" \
                    %(main, main_node_i, sub, sub_node_i)
            main_node = add_trace_node(edge_candiate_anno, main, main_node_i,
                                       label+"-main")
            sub_node = add_trace_node(edge_candiate_anno, sub, sub_node_i,
                                      label+"-sub")
            edge_candiate_anno.addEdge(main_node, sub_node)
        outfile = os.path.join(fpath, "candidates-" + comment_name + '.k.zip')
        skel_obj_candidates.to_kzip(outfile)

    if save_loops:
        # pruned_traces is a list initially but a dict after simplify();
        # in the latter case the traces are the values, not the keys
        pruned = self.pruned_traces
        if isinstance(pruned, dict):
            pruned = pruned.values()
        skel_obj_loops = knossos_skeleton.Skeleton()
        for i,t in enumerate(pruned):
            anno_loop, _ = trace_to_anno(t, comment_name+'-loop%i'%i)
            skel_obj_loops.add_annotation(anno_loop)
        outfile = fname + '-loops.k.zip'
        skel_obj_loops.to_kzip(outfile)

    for i,skel_obj in enumerate(skel_objs):
        outfile = fname + '-c%i.k.zip' %i
        skel_obj.to_kzip(outfile)
def is_loop(self, trace, traces):
    """
    Decide whether ``trace`` is a redundant loop of another trace.

    :param trace: Trace
        test candidate
    :param traces: dict of Traces
        traces to compare against (``trace`` itself is skipped)
    :return: (bool, tuple or None)
        ``is_loop`` flag and, if the trace is kept as a spine branch, the
        ``(start, stop)`` cut indices; otherwise None
    """
    if not len(traces):
        # Callers unpack two values, so always return a pair
        return False, None
    # Determine the average distance of the two end points to all other
    # traces. Use the distances normalised by the radius of the other trace
    end_points = trace.coords[[0,-1]]
    # Materialise the keys once: dict views cannot be indexed on Python 3
    keys = list(traces.keys())
    relative_distances = np.full(len(keys), np.inf)
    for i, tr_i in enumerate(keys):
        tr = traces[tr_i]
        if tr is trace: # don't compare trace to itself, would be loop always
            continue
        dist, ind, coord = tr.kdt.get_knn(end_points, k=1)
        radii = tr.features[ind, 0]
        relative_distances[i] = (dist/radii).mean()
    k = relative_distances.argmin()
    tr_i = keys[k]
    if relative_distances[k] < self.endpoint_thresh:
        # now check if there exists point that is farther away from main
        dist, rel_dist = normalised_min_dist(traces[tr_i], trace.coords.data)
        max_point = rel_dist.argmax()
        if rel_dist[max_point] >= self.spine_thresh :
            is_loop = False
            # Keep the part between the closer end and the farthest point
            cut_a = 0 if rel_dist[0] < rel_dist[-1] else len(rel_dist)
            cut_b = max_point
            cut_0 = min(cut_a, cut_b)
            cut_1 = max(cut_a, cut_b)
            assert cut_0 < cut_1
            cuts = (cut_0, cut_1)
        else:
            is_loop = True
            cuts = None
    else:
        is_loop = False
        cuts = None
    return is_loop, cuts
def closest_approach(self, tr_a, tr_b):
    """
    Find where one of two traces ends closest to the other one.

    The four end-point-to-trace distances (relative to the radius of the
    other trace) are compared; the smallest one decides which trace is the
    "main" trace and which end of the "sub" trace attaches to it.

    :param tr_a: Trace
    :param tr_b: Trace
    :return: None if the smallest relative distance exceeds
        ``self.spine_thresh``, otherwise the tuple
        ``(relative_dist, a_is_main, (main_node, sub_node),
        (cut_start, cut_end), feat)``. ``feat`` is the result of
        ``get_merge_features`` or None if the features could not be computed.
    """
    b0toa = normalised_min_dist(tr_a, tr_b.coords[0])[1]
    b1toa = normalised_min_dist(tr_a, tr_b.coords[-1])[1]
    a0tob = normalised_min_dist(tr_b, tr_a.coords[0])[1]
    a1tob = normalised_min_dist(tr_b, tr_a.coords[-1])[1]
    end_dists = [b0toa, b1toa, a0tob, a1tob]
    case = np.argmin(end_dists)
    geometrict_dist = np.min(end_dists)
    if geometrict_dist > self.spine_thresh :
        return None
    end_match = case in [1,3]   # the *last* point of the sub trace is closest
    a_is_main = case in [0, 1]  # an end of tr_b is closest to tr_a
    main_tr = tr_a if a_is_main else tr_b
    sub_tr = tr_b if a_is_main else tr_a
    # Look at (up to) 20 points at the matching end of the sub trace
    slice_20 = slice(-20, None) if end_match else slice(None, 20)
    sub_coords = sub_tr.coords[slice_20]
    distances, indices, _ = main_tr.kdt.get_knn(sub_coords, k=1)
    max_seg_length = sub_tr.seg_length[slice_20].max()
    sub_merge_candidates = (distances < max_seg_length).nonzero()[0]
    if len(sub_merge_candidates):
        sub_node = sub_merge_candidates[0] if end_match else sub_merge_candidates[-1]
    else:
        sub_node = distances.argmin()
    main_node = indices[sub_node] # Take the main node which was found in knn
    if end_match: # For end_match the index needs to be shifted by the trace length to comply with the indices of sub_coords
        sub_node += len(sub_tr) - np.minimum(20, len(sub_tr))
    cut_start = 0 if end_match else sub_node
    cut_end = sub_node+1 if end_match else len(sub_tr)
    assert cut_end-cut_start>0
    # Check for cases where there is a spine loop
    if (cut_start==sub_node and end_match):
        assert cut_start==0
        cut_end = len(sub_tr)
        end_match = not end_match
    elif (cut_end==sub_node and not end_match):
        assert cut_end==len(sub_tr)
        cut_start = 0
        end_match = not end_match
    cuts = (cut_start, cut_end)
    nodes = (main_node, sub_node)
    # Feature extraction is best effort, but ``feat`` must always be bound:
    # otherwise the return statement fails with UnboundLocalError.
    feat = None
    try:
        feat = get_merge_features(main_tr,main_node,sub_tr,sub_node, end_match)
    except Exception:
        logger.warning("Could not compute merge features for nodes %s",
                       nodes, exc_info=True)
    # merge_feat_main, merge_feat_sub, merge_feat_joint = feat
    return geometrict_dist, a_is_main, nodes, cuts, feat
def make_merge_graph(self):
    """
    Build the merge graph between all traces and derive cuts and components.

    For every pair of traces ``closest_approach`` is evaluated. Found
    connections are stored in ``self.edge_candidates`` and fed, weighted by
    their relative distance, into a minimum spanning tree (forest). The MST
    edges are stored in ``self.edges``, the cuts of the attached ("sub")
    traces are merged into ``self.trace_cuts`` and the connected components
    are stored in ``self.tr_i2comp_i`` / ``self.num_components``.

    Trace keys are used as matrix indices and therefore must be
    non-negative integers.
    """
    # A real list is needed: dict views support neither ``keys[s]`` nor
    # numpy fancy indexing on Python 3
    keys = list(self.traces.keys())
    n = len(keys)
    msd_list= []
    # For all pairwise traces find closest approach, node/cuts indices and
    # features for the edge classifier, collect positive edges in "msd_list"
    for s in range(n):
        for t in range(s+1, n):
            tr_a = keys[s]
            tr_b = keys[t]
            tmp = self.closest_approach(self.traces[tr_a], self.traces[tr_b])
            if tmp is None: # If Components are disconnected still add them
                msd_list.append([tr_a, tr_a, 0])
                msd_list.append([tr_b, tr_b, 0])
                continue
            geometrict_dist, a_is_main, nodes, cuts, feat = tmp
            if np.isclose(geometrict_dist, 0.0):
                geometrict_dist = 0.1
                # otherwise connected components will consider this as split
            if a_is_main:
                main, sub = tr_a, tr_b
            else:
                main, sub = tr_b, tr_a
            main_coord = self.traces[main].coords[nodes[0]]
            sub_coord = self.traces[sub].coords[nodes[1]]
            coords = (main_coord, sub_coord)
            # if edge_classifier(feat) > thresh:
            # don't add edges which are not classified
            self.edge_candidates[(main, sub)] = [geometrict_dist, nodes, cuts, feat, coords]
            msd_list.append([main, sub, geometrict_dist])
    if len(msd_list)==0:
        return
    # Create MST from edge graph (actually MST-Forest)
    msd = np.asarray(msd_list, dtype=np.float64)
    # The mixed int/float rows become a float array; indices must be integer
    main = msd[:, 0].astype(np.intp)
    sub = msd[:, 1].astype(np.intp)
    dist = msd[:, 2]
    a = np.hstack([main, sub])
    b = np.hstack([sub, main])
    values = np.hstack([dist, dist])
    adj_mat = sparse.csr_matrix(sparse.coo_matrix( (values, (a,b)) ))
    mst = csgraph.minimum_spanning_tree(adj_mat)
    edges_mst = np.array(mst.nonzero()).T
    self.edges = edges_mst
    # For all MST-edges update the cuts (cut as much as possible to cover merge positions)
    for edge_mst in edges_mst:
        edge = tuple(int(v) for v in edge_mst)
        try:
            sub = edge[1]
            tmp = self.edge_candidates[edge]
        except KeyError: # MST might turn around edge order
            sub = edge[0]
            tmp = self.edge_candidates[edge[::-1]]
        old_cuts = self.trace_cuts.get(sub, None)
        if old_cuts is None:
            new_cuts = tmp[2]
        else:
            new_cuts = (np.minimum(old_cuts[0], tmp[2][0]),
                        np.maximum(old_cuts[1], tmp[2][1]))
        self.trace_cuts[sub] = new_cuts
    # If edges candidates were classified negative, the components of the
    # MST-forest must be split
    # Returns unconnected nodes (empty slots) as component too!
    num, labels = csgraph.connected_components(mst, directed=False)
    # Therefore select only the stuff which is in keys
    comp_names, components = np.unique(labels[keys], return_inverse=True)
    self.tr_i2comp_i = dict(zip(keys, components))
    self.num_components = comp_names.size
def simplify(self, profile=False):
    """
    Prune traces that are loops of other traces and build the merge graph.

    Traces are visited from shortest to longest run length. Traces that
    ``is_loop`` classifies as loops are moved to ``self.pruned_traces``, the
    rest stay in ``self.traces``. If every trace was pruned, the longest one
    is restored. Finally ``make_merge_graph`` is called.

    :param profile: if True, time the two stages with ``utils.Timer``
    """
    if profile:
        tt = utils.Timer()
    keep_traces = {}
    pruned_traces = {}
    traces = dict(self.traces)
    # Materialise the keys: wrapping a dict view in np.array() yields a
    # 0-d object array on Python 3, which can be neither iterated nor sorted
    keys = list(traces.keys())
    trace_lengths = np.array([traces[tr_i].runlength for tr_i in keys])
    keys_sorted = [keys[j] for j in np.argsort(trace_lengths)]
    for tr_i in keys_sorted:
        tr = traces[tr_i]
        is_loop, cuts = self.is_loop(tr, traces)
        if is_loop:
            # Removing it from ``traces`` keeps later candidates from being
            # compared against an already pruned trace
            pruned_traces[tr_i] = traces.pop(tr_i)
        else: # Dont cut traces here, it might mess up connection parts
            keep_traces[tr_i] = tr
    self.pruned_traces = pruned_traces
    self.traces = keep_traces
    if profile:
        tt.check(name='prune loops')
    # If all traces are loops with another trace take the largest single trace
    if len(self.traces)==0 and pruned_traces:
        longest = max(pruned_traces,
                      key=lambda tr_i: pruned_traces[tr_i].runlength)
        self.traces[longest] = pruned_traces.pop(longest)
    self.make_merge_graph()
    if profile:
        tt.check("merge graph")
def make_segment_lenghts(bone, aniso_scale=(1, 1, 2)):
    """
    Assign every node of a bone the share of path length it represents.

    Each node gets half of the length of each adjacent segment, so the end
    nodes get half a segment, inner nodes the mean of their two segments,
    and the values sum up to the total path length of the bone.

    :param bone: array of node coordinates, one row per node
    :param aniso_scale: per-axis scale applied before measuring distances
    :return: float array with one length per node
    """
    bone = np.asarray(bone)
    n_nodes = len(bone)
    if n_nodes < 2:
        # No segments: a single node (or an empty bone) has no length
        return np.zeros(n_nodes, dtype=np.float64)
    scale = np.asarray(aniso_scale, dtype=np.float64).reshape(1, -1)
    steps = np.linalg.norm(np.diff(bone, n=1, axis=0) * scale, axis=1)
    segment_lengths = np.zeros(n_nodes, dtype=np.float64)
    # Every segment contributes one half to each of its two end nodes.
    # (Averaging over the slices [1:-2]/[2:-1] would leave the
    # second-to-last node with a full segment instead of the mean.)
    segment_lengths[:-1] += 0.5 * steps
    segment_lengths[1:] += 0.5 * steps
    return segment_lengths
def runlength_metric(path_lengths, distances, cut_start=10, cut_max=200, num=50):
    """
    Mean correctly traced path length as a function of the distance cutoff.

    For every cutoff and every trace, the path length reached just before
    the distance first reaches the cutoff is taken (the full length if it
    never does); these lengths are averaged over all traces.

    :param path_lengths: sequence of arrays with accumulated path lengths
    :param distances: sequence of arrays with the matching distances
    :param cut_start: smallest cutoff
    :param cut_max: largest cutoff
    :param num: number of cutoffs
    :return: ``(mean lengths per cutoff, cutoffs)``
    """
    cutoffs = np.linspace(cut_start, cut_max, num=num)
    runlengths = utils.AccumulationArray()
    correct_lenghts = utils.AccumulationArray()
    for cutoff in cutoffs:
        for lengths, dists in zip(path_lengths, distances):
            larger = np.nonzero(dists >= cutoff)[0]
            if len(larger):
                # Last position before the cutoff was reached. Clamp at 0:
                # if the very first sample is already too far away, index -1
                # would wrap around and credit the *full* path length.
                last_ok = max(larger[0]-1, 0)
                correct_lenghts.append(lengths[last_ok])
            else:
                correct_lenghts.append(lengths[-1])
        mean_correct_lenght = correct_lenghts.mean()
        correct_lenghts.clear()
        runlengths.append(mean_correct_lenght)
    return runlengths.data, cutoffs
def runlength_metric_GT(trace, skel=None, cut_start=10, cut_max=200, num=20):
    """
    Length of the ground-truth skeleton that is covered by a trace, as a
    function of the distance cutoff.

    A skeleton node counts as traced for a cutoff if at least one trace
    point lies within that distance of it; the segment lengths of all traced
    nodes are summed over all bones.

    :param trace: trace to evaluate
    :param skel: skeleton to compare against (default: ``trace.skel``)
    :param cut_start: smallest cutoff
    :param cut_max: largest cutoff (also the neighbour search radius)
    :param num: number of cutoffs
    :return: ``(cutoffs, covered lengths per cutoff)``
    """
    if skel is None:
        skel = trace.skel
    cutoffs = np.linspace(cut_start, cut_max, num=num)
    runlengths = np.zeros(num)
    trace_points = trace.coords.data
    trace_kdt = utils.KDT(radius=cut_max, n_jobs=-1)
    trace_kdt.fit(trace_points * np.array([[1,1,2]]))
    for edge, bone in skel.bones.items():
        segment_lengths = make_segment_lenghts(bone)
        dist, ind = trace_kdt.radius_neighbors(bone * np.array([[1,1,2]]))
        for i, cutoff in enumerate(cutoffs):
            # ``np.bool`` has been removed from NumPy; builtin bool is the
            # same dtype
            was_traced = np.zeros(len(bone), dtype=bool)
            for k in range(len(dist)):
                was_traced[k] = np.any(dist[k]<=cutoff)
            runlengths[i] += segment_lengths[was_traced].sum()
    return cutoffs, runlengths
def _set_node_features(node, features, k):
    # Store feature columns 1..6 of trace point k as data elements on node
    elem_names = ("axoness_proba0", "axoness_proba1", "axoness_proba2",
                  "spiness_proba0", "spiness_proba1", "spiness_proba2")
    for column, elem_name in enumerate(elem_names, 1):
        node.setDataElem(elem_name, features[k, column])


def trace_to_anno(trace_xyz, name, anno=None, root=None):
    """
    Convert a trace into a chain of knossos skeleton nodes.

    :param trace_xyz: object with ``coords`` and ``features`` (e.g. a Trace)
    :param name: path-like name; its last component becomes the comment of
        a newly created annotation
    :param anno: annotation to add the nodes to; a new one with scaling
        (9, 9, 20) is created if None
    :param root: index of the node to mark as root; defaults to
        ``trace_xyz.root`` for Trace instances
    :return: ``(anno, node_mapping)`` where ``node_mapping`` maps the trace
        point index to the created node
    """
    if isinstance(trace_xyz, Trace):
        # Features are only used if there is one row per trace point
        feature_avail = len(trace_xyz.features)==len(trace_xyz)
    else:
        feature_avail = True
    radius = 1.0  # used for all nodes if no features are available
    if anno is None:
        anno = knossos_skeleton.SkeletonAnnotation()
        anno.scaling = (9.0, 9.0 ,20.0)
        anno.setComment(os.path.split(name)[1])
    node_mapping = dict()
    trace_coords = np.round(trace_xyz.coords).astype(np.int16)
    if len(trace_coords) == 0:
        # Nothing to export: leave the annotation untouched
        return anno, node_mapping

    last_node = None
    for k in range(len(trace_coords)):
        coord = trace_coords[k]
        if feature_avail:
            radius = trace_xyz.features[k, 0]
        new_node = knossos_skeleton.SkeletonNode()
        new_node.from_scratch(anno, coord[0], coord[1], coord[2], radius=radius)
        if feature_avail:
            # Features of point k belong on the node of point k, not on its
            # predecessor (which would shift every value by one node and
            # leave the last node without features)
            _set_node_features(new_node, trace_xyz.features, k)
        node_mapping[k] = new_node
        anno.addNode(new_node)
        if last_node is not None:
            last_node.addChild(new_node)
        last_node = new_node

    if root is None:
        if isinstance(trace_xyz, Trace) and trace_xyz.root is not None:
            root = trace_xyz.root
    if root is not None:
        node_mapping[root].setRoot()
    return anno, node_mapping
def trace_to_kzip(trace_xyz, fname):
    """
    Write a single trace as a knossos skeleton to ``fname + '.k.zip'``.

    :param trace_xyz: trace to export (converted with ``trace_to_anno``)
    :param fname: output path without the ``.k.zip`` extension
    """
    skeleton = knossos_skeleton.Skeleton()
    annotation, _ = trace_to_anno(trace_xyz, fname)
    skeleton.add_annotation(annotation)
    skeleton.to_kzip(fname + '.k.zip')
def trace_to_kzip_multi(traces, fname):
    """
    Write several traces into one knossos skeleton file
    (``fname + '.k.zip'``, with ``~`` expanded), one annotation per trace.

    :param traces: dict or iterable of traces; an entry may also be a plain
        coordinate array, which is exported as a simple node chain rooted at
        its first point
    :param fname: output path without the ``.k.zip`` extension
    """
    if isinstance(traces, dict):
        traces = traces.values()
    skeleton = knossos_skeleton.Skeleton()
    for i, trace_xyz in enumerate(traces):
        if isinstance(trace_xyz, np.ndarray):
            # Plain coordinates: build the node chain by hand
            points = np.round(trace_xyz).astype(np.int16)
            anno = knossos_skeleton.SkeletonAnnotation()
            anno.scaling = (9.0, 9.0 ,20.0)
            anno.setComment(os.path.split(fname)[1]+"_%i"%i)
            previous = knossos_skeleton.SkeletonNode()
            previous.from_scratch(anno, points[0,0], points[0,1], points[0,2])
            previous.setRoot()
            anno.addNode(previous)
            for point in points[1:]:
                current = knossos_skeleton.SkeletonNode()
                current.from_scratch(anno, point[0], point[1], point[2])
                anno.addNode(current)
                previous.addChild(current)
                previous = current
        else:
            anno, _ = trace_to_anno(trace_xyz, fname+"_%i"%i)
        skeleton.add_annotation(anno)
    skeleton.to_kzip(os.path.expanduser(fname + '.k.zip'))
def bbox_cube_anno(off_xyz, sz_xyz, comment="?", cross_edges=False):
    """
    Create a knossos annotation outlining a box as a wireframe.

    :param off_xyz: offset (corner 0) of the box
    :param sz_xyz: size of the box
    :param comment: prefix of the annotation comment
    :param cross_edges: if True, additionally connect every node of the
        annotation with every node
    :return: the annotation holding the 8 corner nodes and 12 box edges
    """
    off_xyz = np.array(off_xyz)
    sz_xyz = np.array(sz_xyz)
    # Corner order: bottom face (0-3), then top face (4-7)
    unit_corners = np.array([[0,0,0],
                             [1,0,0],
                             [1,1,0],
                             [0,1,0],
                             [0,0,1],
                             [1,0,1],
                             [1,1,1],
                             [0,1,1]])
    cords = off_xyz + sz_xyz * unit_corners

    anno = knossos_skeleton.SkeletonAnnotation()
    anno.scaling = (9.0, 9.0 ,20.0)
    anno.setComment("%s: %s - %s"%(comment, off_xyz,sz_xyz))

    nodes = []
    for x,y,z in cords:
        corner = knossos_skeleton.SkeletonNode()
        corner.from_scratch(anno,x,y,z)
        anno.addNode(corner)
        nodes.append(corner)

    box_edges = [(0,1),(0,3),(0,4), (1,2),(1,5), (3,2),(3,7),
                 (4,5),(4,7), (6,7),(6,5),(6,2)]
    for parent_i, child_i in box_edges:
        nodes[parent_i].addChild(nodes[child_i])

    if cross_edges:
        for n1 in anno.nodes:
            for n2 in anno.nodes:
                n1.addChild(n2)
    return anno