Source code for elektronn2.data.skeleton

# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Marius Killinger
# All rights reserved
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip
# TODO: Python 3 compatibility

# Public API of this module (all names are defined further down in the file).
__all__ = [
    'trace_zyx2xyz',
    'trace_to_kzip',
    'SkeletonMFK',
    'Trace',
]

import os
import sys
from subprocess import check_call
import logging
from collections import OrderedDict

from scipy import interpolate
from scipy import sparse
from scipy.sparse import csgraph
import numpy as np
from knossos_utils import skeleton as knossos_skeleton  # TODO: Mark as dependency when knossos_utils is published.

from .. import utils

from ..config import config
from . import transformations

# General-purpose module logger and a separate channel used for the verbose
# output that is only emitted when ``config.inspection`` is enabled.
logger, inspection_logger = (
    logging.getLogger('elektronn2log'),
    logging.getLogger('elektronn2log-inspection'),
)




with open(os.devnull, 'w') as devnull:
    # mayavi is too dumb to raise an exception when no display is available
    # and instead crashes the whole script, so probe for an X server first.
    try:
        # "xset q" will always succeed to run if an X server is currently running
        check_call(['xset', 'q'], stdout=devnull, stderr=devnull)
        import mayavi.mlab as mlab
        # Don't set backend explicitly, use system default...
    # If "xset q" fails (CalledProcessError), is not installed (OSError) or
    # mayavi cannot be imported/initialised, plotting is disabled. Catch
    # ``Exception`` rather than using a bare ``except`` so that
    # KeyboardInterrupt/SystemExit are not silently swallowed.
    except Exception as e:  # (OSError, ImportError, CalledProcessError, ValueError)
        logger.warning("No mayavi imported, cannot plot skeletons")
        logger.debug("mayavi unavailable: %r", e)
        mlab = None


###############################################################################

# Constants for scaling of radius.
# REF_RADIUS is the radius the adaptive zoom tries to normalise to
# (the new scale factor is REF_RADIUS / true_radius); BASE is the
# multiplicative zoom step and BASE_I its inverse.
REF_RADIUS = 20.0
BASE = 1.3
# Hysteresis exponent for zoom changes: 0.5 means no memory,
# 1.0 means complete non-overlap of the zoom/unzoom thresholds.
HYST = 0.75
assert 0.5 <= HYST <= 1.0
BASE_I = BASE ** -1
# Thresholds on the relative scale change that trigger a zoom step.
BASE_H, BASE_IH = BASE ** HYST, BASE ** -HYST


@utils.timeit
@utils.my_jit(nopython=True)
def insert(cube, coords, i, off):
    """Write the value ``i`` into ``cube`` at every point listed in ``coords``.

    Each row ``coords[k]`` holds three indices; ``off`` is subtracted from
    them component-wise before indexing, i.e. ``off`` is the coordinate that
    maps to ``cube[0, 0, 0]``. ``cube`` is modified in place; nothing is
    returned.
    """
    n_points = coords.shape[0]
    for row in np.arange(n_points):
        c0 = coords[row, 0] - off[0]
        c1 = coords[row, 1] - off[1]
        c2 = coords[row, 2] - off[2]
        cube[c0, c1, c2] = i


@utils.timeit
@utils.my_jit(nopython=True)
def insert_vec(cube, coords, vec, off):
    """Scatter per-point vectors into the last axis of a 4D ``cube``.

    For every point ``p`` (row of ``coords``, shifted by ``off``) and every
    channel ``ch < len(vec[0])``: if the target entry already holds a value
    whose magnitude exceeds ``1e-5``, the write is treated as a collision.
    The entry is then set to NaN (in case of doubt, don't train there) and
    the collision counter is incremented. Otherwise ``vec[p, ch]`` is
    written.

    NOTE(review): ``abs(nan) > 1e-5`` is False, so an entry already marked
    NaN is overwritten by a later insert at the same location, and a
    previously written (near-)zero value is never detected as a collision.
    Confirm this is acceptable for the callers.

    Returns
    -------
    int
        Number of (point, channel) entries that collided.
    """
    n_points = len(coords)
    n_channels = len(vec[0])
    collisions = 0
    for p in np.arange(n_points):
        c0 = coords[p, 0] - off[0]
        c1 = coords[p, 1] - off[1]
        c2 = coords[p, 2] - off[2]
        for ch in np.arange(n_channels):
            if abs(cube[c0, c1, c2, ch]) > 1e-5:
                collisions += 1
                cube[c0, c1, c2, ch] = np.nan  # in case of doubt, don't train here
            else:
                cube[c0, c1, c2, ch] = vec[p, ch]

    return collisions


@utils.timeit
@utils.my_jit(nopython=True)
def ray_cast(max_dists, hull_points, hull_dist, ray_steps, hull_cube, off):
    """March a ray from each hull point and record how far it stays in the hull.

    For every point ``i`` of ``hull_points`` (shifted by ``off`` into
    ``hull_cube`` index space), a ray is advanced in steps of length 0.9
    along ``ray_steps[i]``. The march stops when:

    * the position leaves the bounds of ``hull_cube``, or
    * none of the 8 voxels around the current sub-voxel position is set in
      ``hull_cube``, or
    * more than 200 steps have been taken.

    ``max_dists[i]`` is set in place to
    ``hull_dist[i] + 1e-5 + 0.9 * (number of steps that stayed inside)``.
    ``calc_max_dist_to_skels`` passes ``ray_steps`` normalised to unit
    length in the anisotropy-corrected metric.

    Note: the ``np.int`` alias used previously was removed in NumPy 1.24.
    ``np.intp`` performs the same truncating float->int cast and is
    available in every NumPy version and in numba's nopython mode.
    """
    s = np.float32(0.9)  # step length
    sh = hull_cube.shape
    for i in np.arange(len(hull_points)):  # take hull point
        # initialise dist and position (in hull_cube index space)
        dist = hull_dist[i] + 1e-5
        x = hull_points[i, 0] - off[0]
        y = hull_points[i, 1] - off[1]
        z = hull_points[i, 2] - off[2]
        count = 0
        while True:
            count += 1
            x = x + s * ray_steps[i, 0]
            y = y + s * ray_steps[i, 1]
            z = z + s * ray_steps[i, 2]
            # Truncate to voxel indices once per step (lower corner ...)
            x0 = np.intp(x)
            y0 = np.intp(y)
            z0 = np.intp(z)
            if x0 < 0 or y0 < 0 or z0 < 0:
                break
            # ... and the rounded-up corner used for the upper bound check.
            x1 = np.intp(x + 0.5)
            y1 = np.intp(y + 0.5)
            z1 = np.intp(z + 0.5)
            if x1 >= sh[0] or y1 >= sh[1] or z1 >= sh[2]:
                break
            # search if hull is True in neighbourhood of x,y,z
            found = (hull_cube[x0, y0, z0] or hull_cube[x1, y0, z0] or
                     hull_cube[x0, y1, z0] or hull_cube[x0, y0, z1] or
                     hull_cube[x1, y1, z0] or hull_cube[x1, y0, z1] or
                     hull_cube[x0, y1, z1] or hull_cube[x1, y1, z1])
            if not found:
                break
            if count > 200:
                break

            dist = dist + s

        max_dists[i] = dist

@utils.my_jit(nopython=True, cache=True)
def find_peaks_helper(padded_cube, peak_cube):
    """Copy every local maximum of ``padded_cube`` into ``peak_cube``.

    ``padded_cube`` is ``peak_cube``'s source array padded by one voxel on
    every side (``find_peaks`` uses ``np.pad(cube, 1, mode='constant')``).
    An interior voxel is a peak when it is ``>=`` all 26 neighbours of its
    3x3x3 neighbourhood. Its value is then written to ``peak_cube`` at the
    unpadded position ``[z-1, x-1, y-1]``. Non-peak positions are not
    touched, so ``peak_cube`` is expected to be pre-filled (``find_peaks``
    passes ``np.zeros_like``).
    """
    sh = padded_cube.shape
    for z in np.arange(1, sh[0] - 1):
        for x in np.arange(1, sh[1] - 1):
            for y in np.arange(1, sh[2] - 1):
                center = padded_cube[z, x, y]
                # All 26 neighbours must be checked. Earlier code compared
                # [z+1,x-1,y] and [z-1,x,y+1] twice and skipped
                # [z-1,x+1,y] and [z+1,x,y-1], producing false peaks.
                is_peak = (
                    # 8 corners
                    center >= padded_cube[z-1, x-1, y-1] and
                    center >= padded_cube[z+1, x+1, y+1] and
                    center >= padded_cube[z+1, x+1, y-1] and
                    center >= padded_cube[z+1, x-1, y+1] and
                    center >= padded_cube[z-1, x+1, y+1] and
                    center >= padded_cube[z+1, x-1, y-1] and
                    center >= padded_cube[z-1, x-1, y+1] and
                    center >= padded_cube[z-1, x+1, y-1] and
                    # 6 faces
                    center >= padded_cube[z, x, y+1] and
                    center >= padded_cube[z, x, y-1] and
                    center >= padded_cube[z, x+1, y] and
                    center >= padded_cube[z, x-1, y] and
                    center >= padded_cube[z+1, x, y] and
                    center >= padded_cube[z-1, x, y] and
                    # 12 edges, one opposite pair per line pair
                    center >= padded_cube[z, x+1, y+1] and
                    center >= padded_cube[z, x-1, y-1] and
                    center >= padded_cube[z, x-1, y+1] and
                    center >= padded_cube[z, x+1, y-1] and
                    center >= padded_cube[z+1, x+1, y] and
                    center >= padded_cube[z-1, x-1, y] and
                    center >= padded_cube[z+1, x-1, y] and
                    center >= padded_cube[z-1, x+1, y] and
                    center >= padded_cube[z+1, x, y+1] and
                    center >= padded_cube[z-1, x, y-1] and
                    center >= padded_cube[z-1, x, y+1] and
                    center >= padded_cube[z+1, x, y-1]
                )
                if is_peak:
                    peak_cube[z-1, x-1, y-1] = center

def find_peaks(cube):
    """Locate local maxima of a 3D array.

    The cube is zero-padded by one voxel so that border voxels are also
    compared against (constant 0) neighbours. ``find_peaks_helper`` then
    marks the maxima.

    Returns
    -------
    (flat_indices, values)
        Flat indices into ``cube`` (C order) of the detected peaks and the
        corresponding values, both sorted by ascending value. Peaks whose
        value is exactly 0 are not reported, because they are
        indistinguishable from the zero background of the peak buffer.
    """
    peak_values = np.zeros_like(cube)
    find_peaks_helper(np.pad(cube, 1, mode='constant'), peak_values)
    flat_idx = np.flatnonzero(peak_values)
    values = cube.ravel()[flat_idx]
    order = values.argsort()
    return flat_idx[order], values[order]



# WARNING / NOTE: skeleton objects are in xyz-order
[docs]class SkeletonMFK(object): """ Joints: all branches and end points / node terminatons (nodes not of deg 2) Branches: Joints of degree >= 3 """ @staticmethod
[docs] def find_joints(node_list): joints = {} branches = {} for node in node_list: if node.degree() > 2: # branching point joints[node.ID] = node branches[node.ID] = node if node.degree()==1: joints[node.ID] = node # end point return joints, branches
def __init__(self, aniso_scale=2, name=None, skel_num=None): self.aniso_scale = np.array([[1,1,aniso_scale]], dtype=np.float32) self.bones = dict() self.edges = list() self.branches = dict() self.joints = dict() self.all_nodes = None self.hull_points = None self.hull_skel = dict() self.hull_branch = dict() self.name = name self.skel_num = skel_num self.radii = dict() self.all_radii = None self.joint_radii = None self.props = dict() self.all_props = None self.joint_props = None self.joint_id2joint_index = dict() # For training self.kdt_hull = None self.linked_data = None self.lost_track = False self.position_s = None self.position_l = None self.direction_il = None self.start_new_training = True self.prev_batch = None self.trafo = None self.prev_scale = 1.0 self.prev_gamma = 0.0 self.training_traces = [] self.background_processes = False self._hull_point_bg = dict() self.cnn_grid = None # Old for training self.debug_traces = [] self.debug_traces_current = [] self.debug_grads = [] self.debug_grads_current = []
[docs] def init_from_annotation(self, skeleton_annotatation, min_radius=None, interpolation_resolution=0.5, interpolation_order=1): # Read annotation data structures and convert to dicts and np.ndarrays #print(len(skeleton_annotatation.getNodes())) self.joints, self.branches = self.find_joints(skeleton_annotatation.getNodes()) #print(len(self.joints), len(self.branches)) visited = {n: False for n in skeleton_annotatation.getNodes()} for joint_id, joint in self.joints.items(): directions = joint.getNeighbors() for d in directions: if visited[d]: # we have visited this bone already continue visited[d] = True # mark as visited bone = OrderedDict() # create new bone bone[joint] = True # start the bone at the joint current_node = d # next go to the node in the selected direction while True: bone[current_node] = True if current_node.degree() > 2 or current_node.degree()==1: # At new branch or end point, the bone ends here # add edge between starting joint and this branch if joint_id < current_node.ID: edge = (joint_id, current_node.ID) else: edge = (current_node.ID, joint_id) self.edges.append(edge) break else: # The node has 2 neibgs, one from which we come and # another one to which we go nb = list(current_node.getNeighbors()) assert len(nb) == 2 # Test which node we visit next if nb[0] in bone: assert nb[1] not in bone current_node = nb[1] if nb[1] in bone: assert nb[0] not in bone current_node = nb[0] self.bones[edge] = list(bone.keys()) # Convert bones to arrays for edge, bone in self.bones.items(): self.bones[edge] = np.array([x.getCoordinate() for x in bone], dtype=np.float32) self.radii[edge] = np.array([x.getDataElem('radius') for x in bone], dtype=np.float32) try: axoness_pred = np.array([x.getDataElem('axoness_pred') for x in bone], dtype=np.int16) spiness_pred = np.array([x.getDataElem('spiness_pred') for x in bone], dtype=np.int16) props = np.concatenate([axoness_pred[:,None], spiness_pred[:,None]], axis=1) self.props[edge] = props except KeyError: pass # 
convert joints to arrays self.joint_radii = np.array([x.getDataElem('radius') for x in self.joints.values()], dtype=np.float32) try: axoness_pred = np.array([x.getDataElem('axoness_pred') for x in self.joints.values()], dtype=np.int16) spiness_pred = np.array([x.getDataElem('spiness_pred') for x in self.joints.values()], dtype=np.int16) self.joint_props = np.concatenate([axoness_pred[:,None], spiness_pred[:,None]], axis=1) except KeyError: pass self.joint_id2joint_index = dict(zip(self.joints.keys(), range(len(self.joints)))) self.joints = np.array([x.getCoordinate() for x in self.joints.values()], dtype=np.float32) # convert branches to arrays self.branches = np.array([x.getCoordinate() for x in self.branches.values()], dtype=np.float32) if interpolation_resolution is not None: for edge, bone in self.bones.items(): if len(bone)<=1: continue try: new_bone = self.interpolate_bone(bone,max_k=interpolation_order, resolution=interpolation_resolution) self.radii[edge] = self.interpolate_prop(bone, self.radii[edge], new_bone) except: bone, keep_index = utils.unique_rows(bone) new_bone = self.interpolate_bone(bone,max_k=interpolation_order, resolution=interpolation_resolution) self.radii[edge] = self.interpolate_prop(bone, self.radii[edge][keep_index], new_bone) try: self.props[edge] = self.interpolate_prop(bone, self.props[edge],new_bone, discrete=True) except: pass self.bones[edge] = new_bone self.all_nodes = np.vstack([self.joints,] + list(self.bones.values())) self.all_radii = np.hstack([self.joint_radii,] + list(self.radii.values())) try: self.all_props = np.vstack([self.joint_props, ] + list(self.props.values())) except: pass if min_radius: self.all_radii = np.maximum(self.all_radii, min_radius)
[docs] def save(self, fname): utils.picklesave(self, fname)
[docs] def interpolate_bone(self, bone, max_k=1, resolution=0.5): bone_iso = bone * self.aniso_scale linear_distances = np.linalg.norm(np.diff(bone_iso, axis=0), axis=1) total_dist = linear_distances.sum() k = min(max_k, bone_iso.shape[0]-1) tck, u = interpolate.splprep(bone_iso.T, k=k) n = max(2, int(float(total_dist) / resolution)) new = interpolate.splev(np.linspace(0,1,n), tck) new = np.array(new).T / self.aniso_scale return new
[docs] def interpolate_prop(self, old_bone, old_prop, new_bone, discrete=False): dtype = np.int16 if discrete else np.float32 new_prop = np.zeros((len(new_bone),)+old_prop.shape[1:], dtype=dtype) old_bone_iso = old_bone * self.aniso_scale new_bone_iso = new_bone * self.aniso_scale start_i = 0 stop_i = 1 min_dist = np.linalg.norm(new_bone_iso[0] - old_bone_iso[stop_i]) for i in range(len(new_bone)): dist_start = np.linalg.norm(new_bone_iso[i] - old_bone_iso[start_i]) dist_stop = np.linalg.norm(new_bone_iso[i] - old_bone_iso[stop_i]) min_dist = min(min_dist, dist_stop) if (min_dist < dist_stop) and stop_i+1<len(old_bone): stop_i += 1 start_i += 1 dist_start = dist_stop dist_stop = np.linalg.norm(new_bone_iso[i] - old_bone_iso[stop_i]) min_dist = dist_stop if discrete: if dist_stop > dist_start: new_prop[i] = old_prop[start_i] else: new_prop[i] = old_prop[stop_i] else: d = dist_start + dist_stop new_prop[i] = dist_stop/d * old_prop[start_i] + dist_start/d * old_prop[stop_i] return new_prop
@utils.cache()
[docs] def get_kdtree(self, static_points, k=1, jobs=-1): kdt = utils.KDT(n_neighbors=k, n_jobs=jobs, algorithm='kd_tree', leaf_size=20) kdt.fit(static_points * self.aniso_scale) # change metric) #assert np.all(kdt._fit_X / self.aniso_scale == static_points) return kdt
@utils.cache()
[docs] def get_knn(self, kdt, query_points, k=None): if k is not None: pass #assert k==kdt.n_neighbors else: k = kdt.n_neighbors distances, indices = kdt.kneighbors(query_points * self.aniso_scale, n_neighbors=k) # change metric) static_points = kdt._fit_X.astype(np.float32) # Attention those still have the aniso scale in [:,2] if k==1: indices = indices[:,0] distances = distances[:,0].astype(np.float32) coordinates = static_points[indices] / self.aniso_scale # change to pixel coordinates else: distances = distances.astype(np.float32) coordinates = static_points[indices] / self.aniso_scale # change to pixel coordinates assert coordinates.shape[1] == k return distances, indices, coordinates
[docs] def get_closest_node(self, position_s): kdt = self.get_kdtree(self.all_nodes, k=1, jobs=1) dist, ind, nearest_s = self.get_knn(kdt, position_s) if position_s.ndim==1: dist = dist[0] ind = ind[0] nearest_s = nearest_s[0] return dist.astype(np.float32), ind, nearest_s
### Sampling routines for getting training data ###
[docs] def sample_skel_point(self, rng, joint_ratio=None): n = len(self.all_nodes) if joint_ratio: if rng.rand() < joint_ratio: n = len(self.joints) i = rng.randint(n) node = self.all_nodes[i] return node, i
[docs] def sample_tube_point(self, rng, r_max_scale=0.9, joint_ratio=None): """ This is skeleton node based sampling: Go to a random node, sample a random orthogonal direction go a random distance into direction (uniform over the [0, r_max_scale * local maximal radius]) """ # tt = utils.Timer() if self.hull_points is None: kdt = None else: if self.kdt_hull is None: raise RuntimeError("Hull kdts must be pre initialised") kdt = self.kdt_hull node, node_i = self.sample_skel_point(rng, joint_ratio) direc_iso = self.sample_local_direction_iso(node) local_r = self.all_radii[node_i] * r_max_scale count = 0 max_count = 30 proposal = node clipped = False while True: r = rng.rand() * local_r phi = rng.rand() * 2 * np.pi cos_theta = rng.rand() * 2 - 1 sin_theta = np.sqrt(1 - cos_theta ** 2) x = np.cos(phi) * sin_theta y = np.sin(phi) * sin_theta z = cos_theta rand_vec = np.array([x, y, z]) orthogonal_vec_iso = np.cross(direc_iso, rand_vec) orthogonal_vec_iso /= np.linalg.norm(orthogonal_vec_iso) orthogonal_vec = orthogonal_vec_iso / self.aniso_scale[0] proposal = node + orthogonal_vec * r if kdt is None: return proposal dist, ind, coord = self.get_knn(kdt, proposal) dist = dist[0] if dist < 1.5: # we are within hull: break if count >= max_count / 2 and not clipped: local_r *= 0.5 clipped = True logger.debug("Sample hull point: clipped r") if count >= max_count: logger.debug( "Sample hull point: max count %i reached" % max_count) proposal = node break # tt.check("\tdouble_check") count += 1 return proposal
[docs] def sample_local_direction_iso(self, point, n_neighbors=6): """ For a point gives the local skeleton direction/orientation by fitting a line through the nearest neigbours, sign is randomly assigned """ kdt = self.get_kdtree(self.all_nodes, k=n_neighbors, jobs=1) dist, ind, coord = self.get_knn(kdt, point) dist = dist[0] ind = ind[0] coord = coord[0] # maybe use dist as weights for svd? neibs_iso = coord * self.aniso_scale # transform to iso space uu, dd, vv = np.linalg.svd(neibs_iso - neibs_iso.mean(axis=0)) direc_iso = vv[0] # take largest eigenvector direc_iso /= np.linalg.norm(direc_iso, axis=0) # normalise return direc_iso
[docs] def sample_tracing_direction_iso(self, rng, local_direction_iso, c=0.5): """ Sample a direction close to the local direction there is a prior so that the normalised (0,1) angle of deviation a has this distribution: p(a) = 1/N * (1-c*a), where N= 1 - c/2, tmp is the inverse cdf of this shit """ if rng.rand() > 0.5: # the sign is undefined, choose randomly local_direction_iso *= -1 u = rng.rand() tmp = (1 - np.sqrt(1-(2*c - c**2)*u)) / c # theta scaled between 0 and 1 # theta scaled between 0 and 90 deg in rad i.e. 0 and pi/2 theta = tmp * 0.5 * np.pi max_count = 1000 count = 0 proposal = local_direction_iso while True: proposal = rng.rand(3) * 2 - 1 proposal /= np.linalg.norm(proposal, axis=0) # normalise cos_alpha = np.dot(proposal, local_direction_iso) if cos_alpha < 0: # flip to next best within +/- 90 deg cos_alpha *= -1 proposal *= -1 alpha = np.arccos(cos_alpha) if alpha < theta + 0.01: break count += 1 if count>max_count: logger.debug("Sample tracing directions: max count reached") break return proposal
### Loss and loss gradient for Theano Graph ###
[docs] def get_loss_and_gradient(self, new_position_s, cutoff_inner=1.0/3, rise_factor=0.1): """ prediction_c (zxy) Zoned error surface: flat in inner hull (selected at cutoff_inner) constant gradient in "outer" hull towards nearest inner hull voxel gradient increasing with distance (scaled by rise_factor) for predictions outside hull """ inner_hull, indices = self.get_hull_points_inner(cutoff_inner, return_indices=True) kdt = self.get_kdtree(inner_hull, k=1, jobs=1) dist, ind, nearest_s = self.get_knn(kdt, new_position_s) dist = dist[0] ind = ind[0] nearest_s = nearest_s[0] if config.inspection: inspection_logger.info("nearest_s: %s"% (nearest_s.tolist())) if dist<1.5: # we are within inner hull. The maximal distance if # within hull is excatly: np.linalg.norm(np.multiply( # [0.5, 0.5, 0.6], [1, 1, 2])) = 1.22... --> add some margin loss = 0.0 grad_s = np.zeros((3,), dtype=np.float32) self.lost_track = False else: loss = dist # max dist of closest node max_dist = self.hull_skel['max_dist'][indices[ind]] # pointing from nearest to new position unit_grad = (new_position_s - nearest_s) unit_grad /= np.linalg.norm(unit_grad * self.aniso_scale[0], axis=0) if max_dist > dist: # we are in hull but not in inner tube grad_s = unit_grad * 1.0 self.lost_track = False else: # we are outside hull self.lost_track = True factor = rise_factor * (dist - max_dist) grad_s = unit_grad * (1 + factor) self.debug_traces_current.append(new_position_s) self.debug_grads_current.append(grad_s) loss = np.array([loss,], dtype=np.float32) return loss, grad_s
def _new_training_trace(self, **get_batch_kwargs): """ Preprate skeleton for a new training (sample location/direction, reset stuff) Parameters ---------- get_batch_kwargs """ #tt = utils.Timer() if self.current_trace: if len(self.training_traces)>20: self.training_traces = self.training_traces[-2:] self.training_traces.append(self.current_trace) self.current_trace = Trace(linked_skel=self) r_max_scale = get_batch_kwargs['r_max_scale'] tracing_dir_prior_c = get_batch_kwargs['tracing_dir_prior_c'] joint_ratio = get_batch_kwargs.get('joint_ratio', None) position_s = self.sample_tube_point(self.linked_data.rng, r_max_scale=r_max_scale, joint_ratio=joint_ratio) if config.inspection: inspection_logger.info("Start new training") local_direc_is = self.sample_local_direction_iso(position_s, n_neighbors=6) tracing_direc_is = self.sample_tracing_direction_iso(self.linked_data.rng, local_direc_is, c=tracing_dir_prior_c) self.position_s = position_s self.position_l = position_s[::-1] # from lab2data (xyz)->(zxy) self.direction_il = tracing_direc_is[::-1] # from lab2data (xyz)->(zxy) self.current_trace.append(position_s, coord_cnn=[0,]*3, grad=[0,]*3, features=[0,]*7) self.lost_track = False self.trafo = None #tt.check("final") @staticmethod
[docs] def get_scale_factor(radius, old_factor, scale_strenght): """ Parameters ---------- radius: predicted radius (not the true radius) old_factor: factor by which the radius prediction and the image was scaled scale_strenght: limits the maximal scale factor Returns ------- new_factor """ # if old was large (zoom in), radius is smaller hi = BASE ** (scale_strenght * 2) + 1e-3 # e.g 1.69 for 1.3**2 lo = BASE_I ** (scale_strenght * 4) - 1e-3 # e.g. 0.35 for 1/1.3 ** 4 radius_true = radius / old_factor new_factor = REF_RADIUS / radius_true new_factor = np.clip(new_factor, lo, hi) change = new_factor / old_factor if new_factor > 1.0: # left side if change >= BASE_H: # growing new_factor = old_factor * BASE elif change < BASE_IH: new_factor = old_factor * BASE_I else: new_factor = old_factor elif new_factor < 1.0: # right side if change <= BASE_IH: # zoom out new_factor = old_factor * BASE_I elif change > BASE_H: # zoome in new_factor = old_factor * BASE else: new_factor = old_factor else: new_factor = old_factor if config.inspection: inspection_logger.info("SCALE: %.2f -> %.2f, factor0: %.2f, factor: %.2f" % (radius, radius_true, 20.0 / radius_true, new_factor)) return new_factor
@staticmethod @utils.cache def make_grid(t_grid_sh, z_shift): """ Parameters ---------- t_grid_sh: tagged shape (pixel shape + strides) z_shift: shift of center (positive means more look ahead) Returns ------- points: coordinate list zyx order zz,yy,xx: coordinate meshgrid """ sh = np.array(t_grid_sh.spatial_shape) st = np.array(t_grid_sh.strides) lim = (sh-1) * st + 1 lim //= 2 zz,yy,xx = np.mgrid[-lim[0]:lim[0]:1j * sh[0], -lim[1]:lim[1]:1j * sh[1], -lim[2]:lim[2]:1j * sh[2]] zz += z_shift points = np.hstack([zz.ravel()[:,None], yy.ravel()[:,None], xx.ravel()[:,None]]).astype(np.float32) return points, zz,yy,xx @staticmethod
[docs] def point_potential(r, margin_scale, size, repulsion=None): if repulsion is None: repulsion = 1.0 left = margin_scale * size x = (r - left)/(size - left) v = 1.0 - (x**3*(x*(x*6 - 15) + 10)) # soft step function v = np.minimum(np.maximum(v, 0.0), 1.0) return v * repulsion
[docs] def getbatch(self, prediction, scale_strenght, **get_batch_kwargs): """ Parameters ---------- prediction: [[new_position_c, radius, ]] scale_strenght: limits the maximal scale factor for zoom get_batch_kwargs Returns ------- batch: img, target_img, target_grid, target_node """ get_batch_kwargs = dict(get_batch_kwargs) # copy because we destory it if self.start_new_training: self._new_training_trace(**get_batch_kwargs) self.start_new_training = False scale = 1.0 self.prev_scale = 1.0 self.prev_gamma = np.random.rand() * 2 * np.pi elif np.allclose(prediction, 0): scale = self.prev_scale if config.inspection: inspection_logger.warning("getbatch with no feedback: either " "training on same skel or error") else: prediction = prediction[0] new_position_c = prediction[:3] radius = prediction[3] # this is just the predicted val, not the true new_position_l, tracing_direc_il = self.trafo.cnn_pred2lab_position(new_position_c) new_position_s = new_position_l[::-1] self.position_s = new_position_s self.position_l = new_position_s[::-1] # from lab2data (xyz)->(zxy) self.direction_il = tracing_direc_il scale = self.get_scale_factor(radius, self.prev_scale, scale_strenght) self.prev_scale = scale grid = get_batch_kwargs.pop('grid', False) t_grid_sh = get_batch_kwargs.pop('t_grid_sh', None) z_shift = get_batch_kwargs['z_shift'] get_batch_kwargs.pop('joint_ratio', None) try: if config.inspection: inspection_logger.info("Getslice from position_l %s in " "direction_il %s, SCALE %.2f"%(np.array_str( self.position_l, precision=1, suppress_small=True), self.direction_il, scale)) get_batch_kwargs['gamma'] = self.prev_gamma data_batch = self.linked_data.get_newslice(self.position_l, self.direction_il, scale=scale, **get_batch_kwargs) img, target_img, trafo = data_batch[:3] if grid: raise RuntimeError("The creation of the grid target must" "be testet for spatial coherence again") if not self.cnn_grid: self.cnn_grid = self.make_grid(t_grid_sh, z_shift) grid_coords_c, zz, yy, xx = 
self.cnn_grid #dir_point_s = self.position_s + self.direction_il[::-1]/self.aniso_scale[0] #dir_momentum_s = dir_point_s - self.current_trace.coords[-4:].mean(0) #dir_momentum_ci = trafo.lab_coord2cnn_coord(dir_momentum_s[::-1])*[2,1,1] #directions_ci = grid_coords_c*[2,1,1] #direction_difference = cdist(directions_ci, dir_momentum_ci[None], 'cosine') # 0..2, 45deg thresh: > 1.7 #direction_difference = (direction_difference[:,0] - 1.7).astype(np.float32) #direction_difference[direction_difference<0.0] = 0.0 #direction_difference[np.isnan(direction_difference)] = 0.0 # center if even is NULL #repulsion = 1.0 - direction_difference * 2 # * strength, without factor ~ -25% ### TODO might also make repulsion depending on skel_node instead of grid_position. No WHY? ### TODO repulsion is not smooth enough repulsion = 1.0 grid_coords = trafo.cnn_coord2lab_coord(grid_coords_c,add_offset_l=True) dist, ind, nearest_s = self.get_closest_node(grid_coords[:,::-1]) radii = self.all_radii[ind] target_grid = self.point_potential(dist, 0.1, radii, repulsion) target_grid = target_grid.reshape(zz.shape)[None] # add channel if np.allclose(target_grid, 0.0): logger.warning("WTF") self.debug_shit = [img, target_grid] self.debug_shit2 = [nearest_s, radii] else: target_grid = np.ones((1,1,1,1), dtype=np.float32) # Get bio labels/classes dist, ind, nearest_s = self.get_closest_node(self.position_s) classes = self.all_props[ind] target_node = np.zeros(7, dtype=np.float32) target_node[classes[0]+1] = 1 target_node[classes[1]+4] = 1 target_node[0] = self.all_radii[ind] * scale if config.inspection: inspection_logger.info("target_node %s, (true r: %.1f)" %(target_node, self.all_radii[ind])) batch = (img, target_img, target_grid, target_node) self.trafo = trafo return batch except transformations.WarpingOOBError: if config.inspection: inspection_logger.info("OOB in getbatch") raise transformations.WarpingOOBError("Batch OOB")
[docs] def step_feedback(self, new_position_s, new_direction_is, pred_c, pred_features, cutoff_inner=1.0/3, rise_factor=0.1): inner_hull, indices = self.get_hull_points_inner(cutoff_inner, return_indices=True) kdt = self.get_kdtree(inner_hull, k=1, jobs=1) dist, ind, nearest_s = self.get_knn(kdt, new_position_s) dist = dist[0] ind = ind[0] nearest_s = nearest_s[0] # we are within inner hull. The maximal distance if within hull is 1.2... if dist < 1.5: loss = 0.0 grad_s = np.array([0, 0, 0], dtype=np.float32) else: loss = dist max_dist = self.hull_skel['max_dist'][indices[ind]] # max dist of closest node unit_grad = (new_position_s - nearest_s) # pointing from nearest to new position unit_grad /= np.linalg.norm(unit_grad * self.aniso_scale[0], axis=0) # normalise grad if max_dist > dist: # we are in hull but not in inner tube grad_s = unit_grad * 1.0 else: # we are outside hull self.lost_track = True if config.inspection: inspection_logger.info("Lost track") factor = rise_factor * (dist - max_dist) grad_s = unit_grad * (1 + factor) self.current_trace.append(new_position_s, coord_cnn=pred_c, grad=grad_s, features=pred_features) # Actually the new positions should be set in getbach, but we need to # set them here to because sometimes getbatch might be called without # "start_new_training" and with only zeros as prediction self.position_s = new_position_s self.position_l = new_position_s[::-1] # from lab to data frame (xyz) -> (zxy) self.direction_il = new_direction_is[::-1] # from lab to data frame (xyz) -> (zxy) loss = np.array([loss,], dtype=np.float32) return loss, grad_s, nearest_s
[docs] def step_grid_update(self, grid, radius, bio): pred_features = np.hstack([radius, bio]) flat_indices, scores = find_peaks(grid[0,0]) grid_coords_c, zz, yy, xx = self.cnn_grid preds_c = grid_coords_c[flat_indices] #preds_l = self.trafo.cnn_coord2lab_coord(preds_c,add_offset_l=True) if len(scores): new_position_c = grid_coords_c[flat_indices[-1]] preds_c = new_position_c[None] else: new_position_c = np.array([2,0,0], dtype=np.float32) new_position_l, tracing_direc_il = self.trafo.cnn_pred2lab_position(new_position_c) new_position_s = new_position_l[::-1] new_direction_is = tracing_direc_il[::-1] if config.inspection: inspection_logger.info("GridUpdate, node pred %s" % ( np.array_str(pred_features, precision=2, suppress_small=True),)) inspection_logger.info( "GridUpdate, new_position_c: %s, new_position_l: %s" % ( new_position_c, np.array_str(new_position_l, precision=1, suppress_small=True))) if config.inspection>1: img, grid_t = self.debug_shit utils.picklesave([img[0,0], grid_t[0], grid[0,0]], '/tmp/debug_shit_skel_%i' %self.skel_num) self.current_trace.append(new_position_s, coord_cnn=new_position_c, features=pred_features) # Actually the new positions should be set in getbach, but we need to # set them here to because sometimes getbatch might be called without # "start_new_training" and with only zeros as prediction self.position_s = new_position_s self.position_l = new_position_s[::-1] # from lab to data frame (xyz) -> (zxy) self.direction_il = new_direction_is[::-1] # from lab to data frame (xyz) -> (zxy) return new_position_c[None], preds_c, scores
### Plotting ###
[docs] def plot_skel(self, fig=None): if fig is None: fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400)) x = self.all_nodes[:,0] y = self.all_nodes[:,1] z = self.all_nodes[:,2]*self.aniso_scale[0,2] mlab.points3d(x,y,z, scale_factor=0.8, color=(1,0,0), figure=fig) for bone in self.bones.values(): x = bone[:,0] y = bone[:,1] z = bone[:,2]*self.aniso_scale[0,2] mlab.plot3d(x,y,z,tube_radius=0.4, color=(0.3,0.3,0.3), figure=fig) self._plot_joints(fig=fig) return fig
[docs] def plot_debug_traces(self, grads=True, fig=None): if fig is None: fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400)) traces = np.array(self.debug_traces) for trace in traces: x = trace[:, 0] y = trace[:, 1] z = trace[:, 2] * self.aniso_scale[0, 2] mlab.plot3d(x, y, z, tube_radius=0.2, color=(0.3, 0.3, 0.3), figure=fig) if grads: grads = np.array(self.debug_grads) for grad, trace in zip(grads, traces): x = trace[:, 0] y = trace[:, 1] z = trace[:, 2] * self.aniso_scale[0, 2] gx = -grad[:, 0] gy = -grad[:, 1] gz = -grad[:, 2] * self.aniso_scale[0, 2] mlab.quiver3d(x,y,z, gx, gy, gz, figure=fig, color=(0,0.6,0.2), scale_factor=3) return fig
[docs] def plot_radii(self, fig=None): if fig is None: fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400)) x = self.all_nodes[:,0] y = self.all_nodes[:,1] z = self.all_nodes[:,2]*self.aniso_scale[0,2] r = self.all_radii mlab.points3d(x,y,z,r, scale_mode='scalar', scale_factor=1, color=(0,0.5,0.5), mode='sphere', opacity=0.1, figure=fig) return fig
def _plot_joints(self, fig=None):
    """
    Draw ``self.joints`` as large yellow points.

    :param fig: mayavi figure to draw into; a new one is created if None.
    :return: the figure drawn into.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    joints = self.joints
    mlab.points3d(joints[:, 0], joints[:, 1],
                  joints[:, 2] * self.aniso_scale[0, 2],
                  scale_factor=3, color=(1, 1, 0), figure=fig)
    return fig

### Hull methods ###
def calc_max_dist_to_skels(self):
    """
    For every hull point, ray-cast away from its nearest skeleton node to
    find the maximal distance to the hull border along that direction.

    Reads ``self.hull_points`` (n, 3) and ``self.hull_skel['direc']`` (n, 3)
    and ``self.hull_skel['dist']`` (n,), which are filled by ``map_hull``.

    :return: ``(max_dist, rel_dist)``. ``max_dist`` is filled by ``ray_cast``
        (float32, one value per hull point). ``rel_dist = dist / max_dist``.
        Entries whose ray step is non-finite get ``max_dist = 1.0``.
        NOTE(review): if ray_cast leaves a 0 in max_dist, rel_dist is inf
        there.
    """
    hull = self.hull_points  # (n, 3)
    direc = self.hull_skel['direc']  # (n, 3)
    dist = self.hull_skel['dist']  # (n,) true (anisotropy-corrected) distances
    max_dist = np.zeros(len(hull), dtype=np.float32)

    # Step vectors point away from the skeleton and have unit magnitude in
    # the true (anisotropy-corrected) metric.
    norms = np.linalg.norm(direc * self.aniso_scale, axis=1)[:, None]
    ray_steps = -direc / (norms + 1e-5)

    # Dense boolean cube covering the hull's bounding box, hull voxels set.
    off = np.min(hull, 0)
    sh = np.max(hull, 0) + 1 - off
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    hull_cube = np.zeros(sh, dtype=bool)
    insert(hull_cube, hull, True, off)

    # Cast rays through the dense cube; ray_cast fills max_dist in place.
    ray_cast(max_dist, hull, dist, ray_steps, hull_cube, off)

    # For these points the direc vector has zero magnitude anyway.
    max_dist[np.any(~np.isfinite(ray_steps), axis=1)] = 1.0
    rel_dist = dist / max_dist
    return max_dist, rel_dist
def map_hull(self, hull_points):
    """
    Relate every hull point to its nearest skeleton node and nearest branch
    point.

    Distances already account for the z-anisotropy (true distances), but
    hull coordinates and direction vectors remain in pixel coordinates.

    Side effects: sets ``self.hull_points`` (int16), fills
    ``self.hull_skel`` ('dist', 'ind', 'direc', 'max_dist', 'rel_dist') and
    ``self.hull_branch`` ('dist', 'ind', 'direc'), and builds
    ``self.kdt_hull``.

    :raises ValueError: "InfiniteValue" if any skeleton or branch distance
        is not finite.
    """
    self.hull_points = hull_points.astype(np.int16)
    query = hull_points.astype(np.float32)
    n_points = len(query)

    skel_kdt = self.get_kdtree(self.all_nodes)
    dist_skel, ind_skel, coord_skel = self.get_knn(skel_kdt, query)
    self.hull_skel['dist'] = dist_skel
    self.hull_skel['ind'] = ind_skel
    self.hull_skel['direc'] = coord_skel - query

    max_dist, rel_dist = self.calc_max_dist_to_skels()
    self.hull_skel['max_dist'] = max_dist
    self.hull_skel['rel_dist'] = rel_dist

    if not len(self.branches):
        # No branch points: zero distances and directions, no indices.
        self.hull_branch['dist'] = np.zeros(n_points, dtype=np.float32)
        self.hull_branch['ind'] = None
        self.hull_branch['direc'] = np.zeros((n_points, 3), dtype=np.float32)
    else:
        branch_kdt = self.get_kdtree(self.branches)
        dist_branch, ind_branch, coord_branch = self.get_knn(branch_kdt, query)
        self.hull_branch['dist'] = dist_branch
        self.hull_branch['ind'] = ind_branch
        self.hull_branch['direc'] = coord_branch - query

    skel_finite = np.all(np.isfinite(dist_skel))
    branch_finite = np.all(np.isfinite(self.hull_branch['dist']))
    if not (skel_finite and branch_finite):
        raise ValueError("InfiniteValue")

    # Kept for later use.
    self.kdt_hull = self.get_kdtree(self.hull_points, k=1, jobs=1)
@utils.cache()
def get_hull_points_inner(self, cutoff=1.0/3, return_indices=False):
    """
    Hull points whose relative skeleton distance is below ``cutoff``.

    :param cutoff: threshold on ``self.hull_skel['rel_dist']``.
    :param return_indices: if true, also return the selected row indices.
    :return: the selected points, or ``(points, indices)``.
    """
    inner = self.hull_skel['rel_dist'] < cutoff
    selected = self.hull_points[inner]
    if not return_indices:
        return selected
    return selected, inner.nonzero()[0]
@utils.cache()
def get_hull_branch_direc_cutoff(self, cutoff=25, normalise=False):
    """
    Hull-to-branch direction vectors, zeroed for hull points whose branch
    distance is not below ``cutoff``.

    :param cutoff: threshold on ``self.hull_branch['dist']``.
    :param normalise: if true, divide each vector by its branch distance
        (+1e-5 to avoid division by zero).
    :return: array with the shape of ``self.hull_branch['direc']`` (n, 3).
    """
    within = self.hull_branch['dist'] < cutoff  # (n,)
    # ``direc`` is (n, 3). Add an axis so the mask zeroes whole rows.
    # A plain ``direc * within`` only broadcasts when n == 3, and then it
    # masks columns instead of points.
    ret = self.hull_branch['direc'] * within[:, None]
    if normalise:
        ret /= (self.hull_branch['dist'][:, None] + 1e-5)
    return ret
@utils.cache()
def get_hull_branch_dist_cutoff(self, cutoff=25, normalise=True):
    """
    Branch distances of hull points, zeroed where the distance is not below
    ``cutoff``.

    :param cutoff: threshold on ``self.hull_branch['dist']``.
    :param normalise: if true, return a boolean array that is True where the
        masked distance is > 0.
    :return: float distances or a boolean array (see ``normalise``).
    """
    dist = self.hull_branch['dist']
    ret = dist * (dist < cutoff)
    if normalise:
        ret = ret > 0
    return ret
@utils.cache()
def get_hull_skel_direc_rel(self):
    """Hull-to-skeleton direction vectors divided row-wise by ``max_dist``."""
    skel_info = self.hull_skel
    scale = skel_info['max_dist'][:, None]
    return skel_info['direc'] / scale
def plot_hull(self, fig=None):
    """
    Draw all hull points as translucent white cubes.

    :param fig: mayavi figure to draw into; a new one is created if None.
    :return: the figure drawn into.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    pts = self.hull_points
    mlab.points3d(pts[:, 0], pts[:, 1], pts[:, 2] * self.aniso_scale[0, 2],
                  scale_factor=1, color=(1, 1, 1), mode='cube', opacity=0.1,
                  figure=fig)
    return fig
def plot_hull_inner(self, cutoff, fig=None):
    """
    Draw the hull points selected by ``get_hull_points_inner(cutoff)``.

    :param cutoff: relative skeleton-distance threshold.
    :param fig: mayavi figure to draw into; a new one is created if None.
    :return: the figure drawn into.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    inner = self.get_hull_points_inner(cutoff)
    mlab.points3d(inner[:, 0], inner[:, 1],
                  inner[:, 2] * self.aniso_scale[0, 2],
                  scale_factor=1, color=(0.8, 0.8, 1), mode='cube',
                  opacity=0.1, figure=fig)
    return fig
def plot_vec(self, substep=15, dict_name='skel', key='direc', vec=None,
             fig=None):
    """
    Draw a vector field as arrows at every ``substep``-th hull point.

    :param substep: keep every n-th hull point / vector.
    :param dict_name: 'skel' selects ``self.hull_skel``; anything else
        selects ``self.hull_branch`` (used only if ``vec`` is None).
    :param key: key of the vector array in the selected dict.
    :param vec: explicit (n, 3) vectors overriding the dict lookup.
    :param fig: mayavi figure to draw into; a new one is created if None.
    :return: the figure drawn into.
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    z_scale = self.aniso_scale[0, 2]
    pts = self.hull_points[::substep]
    if vec is None:
        source = self.hull_skel if dict_name == 'skel' else self.hull_branch
        vec = source[key]
    vec = vec[::substep]
    mlab.quiver3d(pts[:, 0], pts[:, 1], pts[:, 2] * z_scale,
                  vec[:, 0], vec[:, 1], vec[:, 2] * z_scale, figure=fig)
    return fig
class Trace(object):
    """
    A traced path: node coordinates plus per-node bookkeeping arrays.

    Unless otherwise stated, all coordinates are in the skeleton system (xyz)
    with an anisotropic z-axis, and all distances are in pixels
    (conversion to micrometres: 1/100).

    Per-node arrays (``utils.AccumulationArray``):

    - ``coords``: node positions.
    - ``seg_length``: distance to the previous node.
    - ``runlengths``: cumulative path length.
    - ``dist_self``: [mean kNN distance to the trace's own earlier nodes,
      normalised value].
    - ``dist_skel``: distance to the linked skeleton; filled only when a
      skeleton is linked.
    - ``uturn_mask``, ``coords_cnn``, ``grads``, ``features``.
    """

    # Per-node arrays that derived traces copy by plain indexing.
    # ``runlengths`` is handled separately because it must be re-based.
    _PER_NODE_FIELDS = ('coords', 'seg_length', 'dist_self', 'dist_skel',
                        'uturn_mask', 'coords_cnn', 'grads', 'features')

    def __init__(self, linked_skel=None, aniso_scale=2, max_cutoff=200,
                 uturn_detection_k=40, uturn_detection_thresh=0.45,
                 uturn_detection_hold=10, feature_count=7):
        """
        :param linked_skel: reference skeleton; if set, each appended node's
            distance to it is recorded.
        :param aniso_scale: scale factor applied to the z coordinate.
        :param max_cutoff: ``lost_track`` is set once the maximal skeleton
            distance exceeds this.
        :param uturn_detection_k: number of neighbours for self-distance.
        :param uturn_detection_thresh: threshold on normalised self-distance.
        :param uturn_detection_hold: number of trailing nodes that must all
            be below the threshold to flag a u-turn.
        :param feature_count: width of the per-node feature rows.
        """
        self.aniso_scale = np.array([[1, 1, aniso_scale]], dtype=np.float32)
        self.skel = linked_skel
        self.lost_track = False
        self.uturn_occurred = False
        self.coords = utils.AccumulationArray(right_shape=3, n_init=500)
        self.seg_length = utils.AccumulationArray(n_init=500)
        self.runlengths = utils.AccumulationArray(n_init=500)
        self.dist_self = utils.AccumulationArray(right_shape=2, n_init=500)
        self.dist_skel = utils.AccumulationArray(n_init=500)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        self.uturn_mask = utils.AccumulationArray(n_init=500, dtype=bool)
        self.coords_cnn = utils.AccumulationArray(right_shape=3, n_init=500)
        self.grads = utils.AccumulationArray(right_shape=3, n_init=500)
        self.features = utils.AccumulationArray(right_shape=feature_count,
                                                n_init=500)
        self.max_cutoff = max_cutoff
        self.uturn_detection_k = uturn_detection_k
        self.uturn_detection_thresh = uturn_detection_thresh
        self.uturn_detection_hold = uturn_detection_hold
        self.kdt = utils.DynamicKDT(k=uturn_detection_k, n_jobs=1,
                                    aniso_scale=self.aniso_scale)
        self.root = 0
        self.comment = ""

    def _rebuild_kdt(self):
        """Rebuild ``self.kdt`` from the current ``self.coords``."""
        if len(self) <= self.uturn_detection_k:
            # Few points: insert one by one instead of bulk construction.
            kdt = utils.DynamicKDT(k=self.uturn_detection_k, n_jobs=1,
                                   aniso_scale=self.aniso_scale)
            for c in self.coords.data:
                kdt.append(c)
        else:
            kdt = utils.DynamicKDT(points=self.coords.data,
                                   k=self.uturn_detection_k, n_jobs=1,
                                   aniso_scale=self.aniso_scale)
        self.kdt = kdt

    def _ensure_comment(self):
        """Return ``self.comment``, creating it as "" if missing (old pickles)."""
        try:
            return self.comment
        except AttributeError:
            self.comment = ""
            return self.comment

    def _derived_trace(self, index, runlengths):
        """
        New Trace with the same settings whose per-node arrays are this
        trace's arrays indexed with ``index``, plus the given ``runlengths``.
        """
        new_trace = Trace(self.skel, self.aniso_scale[0, 2], self.max_cutoff,
                          self.uturn_detection_k, self.uturn_detection_thresh,
                          self.uturn_detection_hold,
                          self.features.data.shape[1:])
        for field in self._PER_NODE_FIELDS:
            setattr(new_trace, field,
                    utils.AccumulationArray(data=getattr(self, field)[index]))
        new_trace.runlengths = utils.AccumulationArray(data=runlengths)
        new_trace._rebuild_kdt()
        return new_trace

    def new_reverted_trace(self):
        """Copy of this trace in reverse node order; root is the old last node."""
        new_trace = self._derived_trace(
            slice(None, None, -1), self.runlengths[-1] - self.runlengths[::-1])
        new_trace.root = len(self) - 1
        new_trace.comment = self._ensure_comment() + " R"
        return new_trace

    def new_cut_trace(self, start, stop):
        """
        Copy of nodes ``start:stop`` with runlengths re-based to ``start``.

        The root is kept (shifted) if it lies inside the cut, else set to None.
        """
        new_trace = self._derived_trace(
            slice(start, stop),
            self.runlengths[start:stop] - self.runlengths[start])
        if self.root is None:
            new_trace.root = None
        else:
            rel_root = self.root - start
            new_trace.root = rel_root if 0 <= rel_root < len(new_trace) else None
        new_trace.comment = self._ensure_comment() + "C%i-%i" % (start, stop)
        return new_trace

    def __len__(self):
        return len(self.coords)

    def save(self, fname):
        """Pickle this trace to ``fname``."""
        utils.picklesave(self, fname)

    def save_to_kzip(self, fname):
        """Write this trace as a KNOSSOS ``.k.zip`` (see ``trace_to_kzip``)."""
        trace_to_kzip(self, fname)

    def add_offset(self, off):
        """Shift all coordinates by ``off`` and rebuild the KD-tree."""
        self.coords.add_offset(np.atleast_2d(off))
        self._rebuild_kdt()

    def append(self, coord, coord_cnn=None, grad=None, features=None):
        """
        Append one node and update all derived statistics.

        Sets ``uturn_occurred`` / ``lost_track`` to ``(len, runlength)`` the
        first time the respective criterion is met.
        """
        self.coords.append(coord)
        if len(self) > 1:
            diff = np.linalg.norm((coord - self.coords[-2]) * self.aniso_scale[0])
        else:
            diff = 5  # just guess
        self.seg_length.append(diff)
        self.runlengths.append(self.runlength)

        normalisation = (self.seg_length.ema
                         * float(self.uturn_detection_k + 1) / 2)
        if len(self) > self.uturn_detection_k + 1:
            distances, indices, coordinates = self.kdt.get_knn(
                coord, k=self.uturn_detection_k)
            dist = distances.mean()
        else:
            # Not enough nodes for a kNN query yet: assume the normal value.
            dist = normalisation
        self.dist_self.append([dist, dist / normalisation])
        self.kdt.append(coord)

        if self.skel:
            dist, index, node = self.skel.get_closest_node(coord)
            self.dist_skel.append(dist)
        if grad is not None:
            self.grads.append(grad)
        if features is not None:
            self.features.append(features)
        if coord_cnn is not None:
            self.coords_cnn.append(coord_cnn)

        # U-turn: the last ``hold`` normalised self-distances are all small.
        last_dist = self.dist_self[-self.uturn_detection_hold:, 1]
        uturn = np.all(last_dist < self.uturn_detection_thresh)
        self.uturn_mask.append(uturn)
        if not self.uturn_occurred and uturn:  # register the first u-turn
            self.uturn_occurred = (len(self), self.runlength)

        # dist_skel stays empty without a linked skeleton; max() of an empty
        # array would raise.
        if not self.lost_track and len(self.dist_skel):
            if self.dist_skel.max() > self.max_cutoff:
                self.lost_track = (len(self), self.runlength)

    def append_serial(self, *args):
        """Call ``append`` once per row of the zipped argument sequences."""
        for arg in zip(*args):
            self.append(*arg)

    @property
    def avg_seg_length(self):
        return self.seg_length.mean()

    @property
    def runlength(self):
        return self.seg_length.sum()

    @property
    def avg_dist_skel(self):
        return self.dist_skel.mean()

    @property
    def max_dist_skel(self):
        return self.dist_skel.max()

    @property
    def avg_dist_self(self):
        return self.dist_self.mean()

    @property
    def min_dist_self(self):
        return self.dist_self.min()[0]

    @property
    def min_normed_dist_self(self):
        return self.dist_self.min()[1]

    def tortuosity(self, start=None, end=None):
        """Arc length divided by chord length between nodes start and end-1."""
        if start is None:
            start = 0
        if end is None:
            end = len(self)
        arc = self.runlengths[end - 1] - self.runlengths[start]
        chord = np.linalg.norm(
            (self.coords[end - 1] - self.coords[start]) * self.aniso_scale[0])
        return arc / chord

    def plot(self, grads=True, skel=True, rand_color=False, fig=None):
        """Draw the trace (and optionally skeleton and gradients) with mayavi."""
        if fig is None:
            fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
        if skel and self.skel:
            fig = self.skel.plot_skel(fig=fig)
        z_scale = self.aniso_scale[0, 2]
        x = self.coords[:, 0]
        y = self.coords[:, 1]
        z = self.coords[:, 2] * z_scale
        line_c = tuple(np.random.rand(3)) if rand_color else (0, 0, 0.7)
        point_c = line_c if rand_color else (0.6, 0.7, 0.9)
        mlab.plot3d(x, y, z, tube_radius=0.2, color=line_c, figure=fig)
        mlab.points3d(x, y, z, scale_factor=0.8, color=point_c, figure=fig)
        if grads and self.grads.length:
            gx = -self.grads[:, 0]
            gy = -self.grads[:, 1]
            gz = -self.grads[:, 2] * z_scale
            mlab.quiver3d(x, y, z, gx, gy, gz, figure=fig, color=(0, 0.6, 0.2),
                          scale_factor=3)
        return fig

    def split_uturns(self, return_accum_pathlength=False, print_stat=False):
        """
        Split the trace at u-turn regions given by ``uturn_mask``.

        :return: list of new traces, or (with ``return_accum_pathlength``)
            ``(traces, accumulated_pathlengths, accumulated_dist_skel)``.
        """
        transitions = np.diff(self.uturn_mask, axis=0, n=1)
        transitions = np.nonzero(transitions)[0]
        transitions[0::2] -= self.uturn_detection_hold  # if add: end segment closer to uturn
        transitions[1::2] -= self.uturn_detection_hold  # if subtract: start new segment closer to uturn
        transitions = np.minimum(np.maximum(0, transitions), len(self))
        transitions = np.hstack((0, transitions, len(self)))

        new_traces = []
        accum_pathlengths = []
        accum_dist_skel = []
        accum_runlength = 0
        for i in range(0, len(transitions) - 1, 2):
            new = self.__class__(self.skel, self.aniso_scale[0, 2],
                                 self.max_cutoff, self.uturn_detection_k,
                                 self.uturn_detection_thresh,
                                 self.uturn_detection_hold)
            start, stop = transitions[i], transitions[i + 1]
            if print_stat:
                print("cutting between %i and %i " % (start, stop))
            if start < stop:  # some transitions are too short
                new.append_serial(self.coords[start:stop])
                new_traces.append(new)
                # Accumulate path lengths and skeleton distances over splits.
                runlengths = self.runlengths[start:stop]
                runlengths = runlengths + accum_runlength - runlengths[0]  # shift
                accum_pathlengths.append(runlengths)
                accum_runlength = runlengths[-1]
                if self.skel:
                    dist_skel = self.dist_skel[start:stop].copy()
                    if i > 0:
                        # Make the evaluation stop if the trace deviated from
                        # the skeleton too much during the u-turn.
                        max_dist_in_uturn = np.max(
                            self.dist_skel[transitions[i - 1]:start])
                        dist_skel[0] = np.maximum(dist_skel[0], max_dist_in_uturn)
                    accum_dist_skel.append(dist_skel)
                else:
                    accum_dist_skel.append([])

        if return_accum_pathlength:
            return (new_traces, np.hstack(accum_pathlengths),
                    np.hstack(accum_dist_skel))
        return new_traces
def normalised_min_dist(tr, point):
    """
    Distance from ``point`` to the nearest node of trace ``tr``.

    :return: ``(dist, dist / radius)``, where radius is feature column 0 of
        that nearest node.
    """
    dist, ind, coord = tr.kdt.get_knn(point, k=1)
    radius = tr.features[ind, 0]
    return dist, dist / radius


def simple_stats(a):
    """Column-wise mean, std, min and max of ``a``, stacked as rows (4, ...)."""
    return np.array([np.mean(a, axis=0), np.std(a, axis=0),
                     np.min(a, axis=0), np.max(a, axis=0)])


def radius_hist(r):
    """Density histogram of radii over fixed bins [0, 8, 14, 23, 35, 50, 80, 200]."""
    bins = np.array([0, 8, 14, 23, 35, 50, 80, 200])
    counts, _ = np.histogram(r, bins=bins, density=True)
    return counts


def get_merge_features(main_tr, main_node, sub_tr, sub_node, end_match):
    """
    Feature vectors describing a candidate merge of ``sub_tr`` into ``main_tr``.

    :param end_match: if true, sub-trace context is taken before ``sub_node``,
        otherwise after it.
    :return: ``(main_features, sub_features, joint_features)``.
    :raises ValueError: if the sub-trace context slice contains no nodes.
    """
    # Main trace: local PCA of 10 nodes, feature stats of 10 / 50 nodes.
    m_slice_small = slice(max(0, main_node - 5), main_node + 5)
    m_slice_large = slice(max(0, main_node - 25), main_node + 25)
    m_points = main_tr.coords[m_slice_small] * main_tr.aniso_scale
    _, pc_m, pc_dir_m = np.linalg.svd(m_points - m_points.mean(0))
    m_feat_small = simple_stats(main_tr.features[m_slice_small])
    m_feat_large = simple_stats(main_tr.features[m_slice_large])
    m_radius_hist = radius_hist(main_tr.features[m_slice_large, 0])
    m_tortuosity = main_tr.tortuosity()
    main_features = np.hstack([pc_m, pc_dir_m.ravel(), m_feat_small.ravel(),
                               m_feat_large.ravel(), m_radius_hist,
                               m_tortuosity])

    # Sub trace: context on the side of sub_node given by end_match.
    if end_match:
        s_slice_small = slice(max(0, sub_node - 10), sub_node)
        s_slice_large = slice(max(0, sub_node - 50), sub_node)
    else:
        s_slice_small = slice(sub_node, sub_node + 10)
        s_slice_large = slice(sub_node, sub_node + 50)
    s_points = sub_tr.coords[s_slice_small] * sub_tr.aniso_scale
    if len(s_points) == 0:
        raise ValueError("No sub-trace nodes in %s, cannot compute merge "
                         "features" % (s_slice_small,))
    _, pc_s, pc_dir_s = np.linalg.svd(s_points - s_points.mean(0))
    s_feat_small = simple_stats(sub_tr.features[s_slice_small])
    s_feat_large = simple_stats(sub_tr.features[s_slice_large])
    s_radius_hist = radius_hist(sub_tr.features[s_slice_large, 0])
    s_tortuosity = sub_tr.tortuosity()
    sub_features = np.hstack([pc_s, pc_dir_s.ravel(), s_feat_small.ravel(),
                              s_feat_large.ravel(), s_radius_hist,
                              s_tortuosity])

    dist = np.linalg.norm((main_tr.coords[main_node] - sub_tr.coords[sub_node])
                          * sub_tr.aniso_scale)
    r_m = main_tr.features[main_node, 0]
    r_s = sub_tr.features[sub_node, 0]
    pc_dir_similarity = np.abs(np.dot(pc_dir_m, pc_dir_s.T))
    joint_features = np.hstack([dist, dist / r_m, dist / r_s,
                                2 * dist / (r_m + r_s),
                                pc_dir_similarity.ravel()])
    return main_features, sub_features, joint_features


def split_tree_components(tracetree, cut=False):
    """
    Split a TraceTree into one TraceTree per connected component.

    :param cut: if true, apply ``tracetree.trace_cuts`` to the traces.
    """
    if tracetree.num_components == 1:
        return [tracetree, ]
    grouped = [list() for _ in range(tracetree.num_components)]
    for tr_i, tr in tracetree.traces.items():
        cuts = tracetree.trace_cuts.get(tr_i, None)
        if cuts and cut:
            tr = tr.new_cut_trace(*cuts)
        grouped[tracetree.tr_i2comp_i[tr_i]].append(tr)
    return [TraceTree(group, tracetree.spine_thresh, tracetree.endpoint_thresh)
            for group in grouped]


class TraceTree(object):
    """Set of traces that is pruned of loops and joined via a merge graph."""

    def __init__(self, traces, spine_thresh=1.5, endpoint_thresh=0.8):
        """
        :param traces: list or dict of Trace. A list is converted to a dict
            keyed 0..n-1.
        :param spine_thresh: float
            How large the maximal relative distance needs to be for a loop
            to be retained as a spine branch.
        :param endpoint_thresh: float
            Threshold on the relative distance between an endpoint and
            another trace for it to count as a connection.
        """
        if not isinstance(traces, dict):
            traces = dict(zip(range(len(traces)), traces))
        self.traces = traces
        self.trace_cuts = dict()
        self.pruned_traces = []
        self.edge_candidates = dict()
        self.edges = []
        self.tr_i2comp_i = None
        self.num_components = 1
        self.aniso = np.array([[1, 1, 2]])
        self.spine_thresh = spine_thresh
        self.endpoint_thresh = endpoint_thresh
        self.joined_kdt = None
        self.joined_coords = None
        self.joined_radii = None

    def build_joined_features(self):
        """Stack the coordinates and radii of all traces and build one KD-tree."""
        self.joined_coords = np.vstack([tr.coords for tr in self.traces.values()])
        self.joined_radii = np.hstack([tr.features[:, 0]
                                       for tr in self.traces.values()])
        self.joined_kdt = utils.DynamicKDT(self.joined_coords, n_jobs=-1,
                                           aniso_scale=[1, 1, 2], k=1)

    def cut_traces_inplace(self):
        """Replace each trace that has an entry in ``trace_cuts`` by its cut version."""
        for tr_i in list(self.traces):
            cuts = self.trace_cuts.get(tr_i, None)
            if cuts:
                self.traces[tr_i] = self.traces[tr_i].new_cut_trace(*cuts)

    def to_kzip(self, fname, save_loops=False, save_edge_candiates=False,
                add_edges=False, save_edges=False):
        """
        Write one ``<fname>-c<i>.k.zip`` per component.

        Optionally also writes edges, edge candidates and pruned loops.
        Auxiliary files are placed next to ``fname``.
        """
        fname = os.path.expanduser(fname)
        fpath, comment_name = os.path.split(fname)

        skel_objs = []
        component_annos = []
        for c in range(self.num_components):
            skel_obj = knossos_skeleton.Skeleton()
            skel_objs.append(skel_obj)
            anno_ = knossos_skeleton.SkeletonAnnotation()
            anno_.scaling = (9.0, 9.0, 20.0)
            anno_.setComment(comment_name + "-c%i" % c)
            skel_obj.add_annotation(anno_)
            component_annos.append(anno_)

        # Save all cut traces to the annotation of their component.
        node_mappings = dict()
        for tr_i, tr in self.traces.items():
            c = self.tr_i2comp_i[tr_i] if self.tr_i2comp_i is not None else 0
            cuts = self.trace_cuts.get(tr_i, None)
            if cuts:
                tr = tr.new_cut_trace(*cuts)
            _, node_mapping = trace_to_anno(tr, fname, component_annos[c])
            node_mappings[tr_i] = node_mapping

        # Save all edges (between cut points) to an edge annotation.
        if len(self.edges) and save_edges:
            edge_anno = knossos_skeleton.SkeletonAnnotation()
            edge_anno.scaling = (9.0, 9.0, 20.0)
            edge_anno.setComment("Edges-" + comment_name)
            skel_obj_edges = knossos_skeleton.Skeleton()
            skel_obj_edges.add_annotation(edge_anno)
            for e in self.edges:
                try:
                    main, sub = e
                    main_node_i, sub_node_i = self.edge_candidates[tuple(e)][1]
                except KeyError:  # MST might turn around edge order
                    main_node_i, sub_node_i = self.edge_candidates[tuple(e)[::-1]][1]
                    main, sub = e[::-1]
                main_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[main].coords[main_node_i]).astype(np.int16)
                main_node.from_scratch(edge_anno, x, y, z)
                edge_anno.addNode(main_node)
                sub_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[sub].coords[sub_node_i]).astype(np.int16)
                sub_node.from_scratch(edge_anno, x, y, z)
                edge_anno.addNode(sub_node)
                edge_anno.addEdge(main_node, sub_node)
                if add_edges:
                    main_cut = self.trace_cuts.get(main, [0, None])[0]
                    sub_cut = self.trace_cuts.get(sub, [0, None])[0]
                    try:
                        n_main = node_mappings[main][main_node_i - main_cut]
                        n_sub = node_mappings[sub][sub_node_i - sub_cut]
                    except KeyError:
                        # The node lies outside the kept (cut) part of its trace.
                        logger.debug("Edge %s-%s not added: node was cut away",
                                     main, sub)
                        continue
                    n_main.annotation.addEdge(n_main, n_sub)
            # os.path.join: fpath is '' for bare file names; '/' + name
            # would then write to the filesystem root.
            skel_obj_edges.to_kzip(
                os.path.join(fpath, "edges-" + comment_name + '.k.zip'))

        # As Node==Edge==Node in one skeleton tree (for making GT in knossos).
        if save_edge_candiates and self.edge_candidates:
            edge_candiate_anno = knossos_skeleton.SkeletonAnnotation()
            edge_candiate_anno.scaling = (9.0, 9.0, 20.0)
            edge_candiate_anno.setComment(comment_name + "-Edge-Candiates")
            skel_obj_candidates = knossos_skeleton.Skeleton()
            skel_obj_candidates.add_annotation(edge_candiate_anno)
            for (main, sub), candidate in self.edge_candidates.items():
                main_node_i, sub_node_i = candidate[1]
                label = comment_name + "-M%i_%i-S%i_%i" % (main, main_node_i,
                                                          sub, sub_node_i)
                main_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[main].coords[main_node_i]).astype(np.int16)
                main_node.from_scratch(edge_candiate_anno, x, y, z)
                main_node.setComment(label + "-main")
                edge_candiate_anno.addNode(main_node)
                sub_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[sub].coords[sub_node_i]).astype(np.int16)
                sub_node.from_scratch(edge_candiate_anno, x, y, z)
                sub_node.setComment(label + "-sub")
                edge_candiate_anno.addNode(sub_node)
                edge_candiate_anno.addEdge(main_node, sub_node)
            skel_obj_candidates.to_kzip(
                os.path.join(fpath, "candidates-" + comment_name + '.k.zip'))

        if save_loops:
            # simplify() stores pruned traces as a dict; __init__ uses a list.
            loops = self.pruned_traces
            if isinstance(loops, dict):
                loops = list(loops.values())
            skel_obj_loops = knossos_skeleton.Skeleton()
            for i, t in enumerate(loops):
                anno_loop, _ = trace_to_anno(t, comment_name + '-loop%i' % i)
                skel_obj_loops.add_annotation(anno_loop)
            skel_obj_loops.to_kzip(fname + '-loops.k.zip')

        for i, skel_obj in enumerate(skel_objs):
            skel_obj.to_kzip(fname + '-c%i.k.zip' % i)

    def is_loop(self, trace, traces):
        """
        Decide whether ``trace`` loops back onto one of the other ``traces``.

        Both endpoints of ``trace`` are compared, via radius-normalised
        distances, to each other trace. If the closest trace is below
        ``endpoint_thresh`` and no node of ``trace`` gets farther than
        ``spine_thresh`` from it, ``trace`` is a loop. Otherwise a cut range
        up to the farthest node is suggested.

        :param trace: Trace test candidate
        :param traces: dict of Traces (may contain ``trace`` itself)
        :return: ``(is_loop, cuts)``; ``cuts`` is ``(start, stop)`` or None.
        """
        if not len(traces):
            return False, None
        end_points = trace.coords[[0, -1]]
        keys = list(traces.keys())
        relative_distances = np.ones(len(keys)) * np.inf
        for i, key in enumerate(keys):
            tr = traces[key]
            if tr is trace:  # don't compare trace to itself, would be loop always
                continue
            dist, ind, coord = tr.kdt.get_knn(end_points, k=1)
            radii = tr.features[ind, 0]
            relative_distances[i] = (dist / radii).mean()
        k = relative_distances.argmin()
        tr_i = keys[k]  # list(): dict views are not indexable in Python 3

        if not relative_distances[k] < self.endpoint_thresh:
            return False, None
        # Check whether some point is farther away from the main trace.
        dist, rel_dist = normalised_min_dist(traces[tr_i], trace.coords.data)
        max_point = rel_dist.argmax()
        if not rel_dist[max_point] >= self.spine_thresh:
            return True, None
        cut_a = 0 if rel_dist[0] < rel_dist[-1] else len(rel_dist)
        cut_b = max_point
        cut_0 = min(cut_a, cut_b)
        cut_1 = max(cut_a, cut_b)
        assert cut_0 < cut_1
        return False, (cut_0, cut_1)

    def closest_approach(self, tr_a, tr_b):
        """
        Find where the endpoint of one trace comes closest to the other trace.

        :param tr_a: Trace
        :param tr_b: Trace
        :return: None if the closest endpoint distance exceeds
            ``spine_thresh``. Otherwise
            ``(geometrict_dist, a_is_main, (main_node, sub_node), cuts, feat)``,
            where ``feat`` is None if the merge features could not be
            computed.
        """
        b0toa = normalised_min_dist(tr_a, tr_b.coords[0])[1]
        b1toa = normalised_min_dist(tr_a, tr_b.coords[-1])[1]
        a0tob = normalised_min_dist(tr_b, tr_a.coords[0])[1]
        a1tob = normalised_min_dist(tr_b, tr_a.coords[-1])[1]
        end_dists = [b0toa, b1toa, a0tob, a1tob]
        case = np.argmin(end_dists)
        geometrict_dist = np.min(end_dists)
        if geometrict_dist > self.spine_thresh:
            return None

        end_match = case in [1, 3]
        a_is_main = case in [0, 1]
        main_tr = tr_a if a_is_main else tr_b
        sub_tr = tr_b if a_is_main else tr_a

        slice_20 = slice(-20, None) if end_match else slice(None, 20)
        sub_coords = sub_tr.coords[slice_20]
        distances, indices, coordinates = main_tr.kdt.get_knn(sub_coords, k=1)
        max_seg_length = sub_tr.seg_length[slice_20].max()
        sub_merge_candidates = (distances < max_seg_length).nonzero()[0]
        if len(sub_merge_candidates):
            sub_node = (sub_merge_candidates[0] if end_match
                        else sub_merge_candidates[-1])
        else:
            sub_node = distances.argmin()
        main_node = indices[sub_node]  # the main node found in knn
        if end_match:
            # Shift the index from sub_coords positions to sub_tr positions.
            sub_node += len(sub_tr) - np.minimum(20, len(sub_tr))

        cut_start = 0 if end_match else sub_node
        cut_end = sub_node + 1 if end_match else len(sub_tr)
        assert cut_end - cut_start > 0
        # Check for cases where there is a spine loop.
        if cut_start == sub_node and end_match:
            assert cut_start == 0
            cut_end = len(sub_tr)
            end_match = not end_match
        elif cut_end == sub_node and not end_match:
            assert cut_end == len(sub_tr)
            cut_start = 0
            end_match = not end_match

        cuts = (cut_start, cut_end)
        nodes = (main_node, sub_node)
        try:
            feat = get_merge_features(main_tr, main_node, sub_tr, sub_node,
                                      end_match)
        except (ValueError, IndexError, np.linalg.LinAlgError) as err:
            # Previously swallowed, leaving ``feat`` unbound (UnboundLocalError).
            logger.warning("Could not compute merge features: %s", err)
            feat = None
        return geometrict_dist, a_is_main, nodes, cuts, feat

    def make_merge_graph(self):
        """
        Build the MST-forest of trace connections.

        Fills ``edge_candidates``, ``edges``, ``trace_cuts``, ``tr_i2comp_i``
        and ``num_components``.
        """
        keys = list(self.traces.keys())
        n = len(keys)
        msd_list = []
        # For all pairwise traces find closest approach, node/cut indices and
        # features for the edge classifier; collect edges in msd_list.
        for s in range(n):
            for t in range(s + 1, n):
                tr_a = keys[s]
                tr_b = keys[t]
                tmp = self.closest_approach(self.traces[tr_a], self.traces[tr_b])
                if tmp is None:
                    # Disconnected components are still added (as self-entries).
                    msd_list.append([tr_a, tr_a, 0])
                    msd_list.append([tr_b, tr_b, 0])
                    continue
                geometrict_dist, a_is_main, nodes, cuts, feat = tmp
                if np.isclose(geometrict_dist, 0.0):
                    # Otherwise connected components would treat this as a split.
                    geometrict_dist = 0.1
                main, sub = (tr_a, tr_b) if a_is_main else (tr_b, tr_a)
                main_coord = self.traces[main].coords[nodes[0]]
                sub_coord = self.traces[sub].coords[nodes[1]]
                coords = (main_coord, sub_coord)
                self.edge_candidates[(main, sub)] = [geometrict_dist, nodes,
                                                     cuts, feat, coords]
                msd_list.append([main, sub, geometrict_dist])

        if len(msd_list) == 0:
            return

        # Create MST from the edge graph (actually an MST-forest).
        main, sub, dist = np.array(msd_list).T
        # Trace ids come back as floats (mixed with distances), but sparse
        # matrices need integer indices.
        a = np.hstack([main, sub]).astype(np.intp)
        b = np.hstack([sub, main]).astype(np.intp)
        values = np.hstack([dist, dist])
        adj_mat = sparse.csr_matrix(sparse.coo_matrix((values, (a, b))))
        mst = csgraph.minimum_spanning_tree(adj_mat)
        edges_mst = np.array(mst.nonzero()).T
        self.edges = edges_mst

        # For all MST-edges update the cuts (cut as much as possible to cover
        # merge positions).
        for edge_mst in edges_mst:
            try:
                sub = edge_mst[1]
                tmp = self.edge_candidates[tuple(edge_mst)]
            except KeyError:  # MST might turn around edge order
                sub = edge_mst[0]
                tmp = self.edge_candidates[tuple(edge_mst)[::-1]]
            old_cuts = self.trace_cuts.get(sub, None)
            if old_cuts is None:
                new_cuts = tmp[2]
            else:
                new_cuts = (np.minimum(old_cuts[0], tmp[2][0]),
                            np.maximum(old_cuts[1], tmp[2][1]))
            self.trace_cuts[sub] = new_cuts

        # If edge candidates were classified negative, the components of the
        # MST-forest must be split. connected_components also returns
        # unconnected nodes (empty slots) as components, so select only keys.
        num, labels = csgraph.connected_components(mst, directed=False)
        comp_names, components = np.unique(labels[keys], return_inverse=True)
        self.tr_i2comp_i = dict(zip(keys, components))
        self.num_components = comp_names.size

    def simplify(self, profile=False):
        """
        Prune loop traces, shortest first, then build the merge graph.

        If every trace is pruned, the longest pruned trace is kept.
        """
        if profile:
            tt = utils.Timer()
        keep_traces = {}
        pruned_traces = {}
        traces = dict(self.traces)
        # list(): np.array(dict.keys()) is a 0-d object array in Python 3.
        keys = np.array(list(traces.keys()))
        trace_lengths = np.array([traces[tr_i].runlength for tr_i in keys])
        keys_sorted = keys[np.argsort(trace_lengths)]
        for tr_i in keys_sorted:
            tr = traces[tr_i]
            is_loop, cuts = self.is_loop(tr, traces)
            if is_loop:
                pruned_traces[tr_i] = traces.pop(tr_i)
            else:
                # Don't cut traces here, it might mess up connection parts.
                keep_traces[tr_i] = tr

        self.pruned_traces = pruned_traces
        self.traces = keep_traces
        if profile:
            tt.check(name='prune loops')

        # If all traces are loops with another trace, take the largest one.
        if len(self.traces) == 0 and pruned_traces:
            tr_lengths = np.array([(tr_i, tr.runlength)
                                   for tr_i, tr in pruned_traces.items()])
            i = int(tr_lengths[np.argmax(tr_lengths[:, 1]), 0])
            self.traces[i] = pruned_traces.pop(i)

        self.make_merge_graph()
        if profile:
            tt.check("merge graph")


def make_segment_lenghts(bone, aniso_scale=(1, 1, 2)):
    """
    Path length attributed to each node of ``bone`` (an (n, 3) array).

    Each node gets half of each adjacent segment (end nodes get half of their
    single segment), so the result sums to the total path length.

    :param aniso_scale: per-axis scale applied to coordinate differences.
    :return: float array of length n (``[0.]`` for a single node).
    """
    scale = np.array([aniso_scale], dtype=float)
    segment_lengths = np.linalg.norm(np.diff(bone, n=1, axis=0) * scale, axis=1)
    # Index i now holds the length of the segment ending at node i.
    segment_lengths = np.hstack(([0, ], segment_lengths))
    if len(segment_lengths) < 2:
        return segment_lengths
    segment_lengths[0] = segment_lengths[1] * 0.5
    # Interior node i: half of segment i plus half of segment i+1. The old
    # [1:-2] slice skipped node n-2.
    segment_lengths[1:-1] = (segment_lengths[1:-1] + segment_lengths[2:]) * 0.5
    segment_lengths[-1] = segment_lengths[-1] * 0.5
    return segment_lengths


def runlength_metric(path_lengths, distances, cut_start=10, cut_max=200,
                     num=50):
    """
    Mean correctly traced path length for a range of distance cutoffs.

    For each trace, the path length up to the node before the first distance
    ``>= cutoff`` counts as correct (0 if already the first node fails, the
    full length if none fails).

    :return: ``(mean_correct_lengths, cutoffs)``.
    """
    cutoffs = np.linspace(cut_start, cut_max, num=num)
    runlengths = utils.AccumulationArray()
    correct_lengths = utils.AccumulationArray()
    for cutoff in cutoffs:
        for i, distance in enumerate(distances):
            larger = np.nonzero(distance >= cutoff)[0]
            if not len(larger):
                correct_lengths.append(path_lengths[i][-1])
            elif larger[0] == 0:
                # Index -1 would wrongly credit the full length here.
                correct_lengths.append(0.0)
            else:
                correct_lengths.append(path_lengths[i][larger[0] - 1])
        runlengths.append(correct_lengths.mean())
        correct_lengths.clear()
    return runlengths.data, cutoffs


def runlength_metric_GT(trace, skel=None, cut_start=10, cut_max=200, num=20):
    """
    For each cutoff, the total ground-truth path length of ``skel`` bones
    whose nodes lie within ``cutoff`` of some trace node.

    :return: ``(cutoffs, runlengths)``.
    """
    if skel is None:
        skel = trace.skel
    cutoffs = np.linspace(cut_start, cut_max, num=num)
    runlengths = np.zeros(num)
    aniso = np.array([[1, 1, 2]])
    trace_kdt = utils.KDT(radius=cut_max, n_jobs=-1)
    trace_kdt.fit(trace.coords.data * aniso)
    for bone in skel.bones.values():
        segment_lengths = make_segment_lenghts(bone)
        dist, ind = trace_kdt.radius_neighbors(bone * aniso)
        for i, cutoff in enumerate(cutoffs):
            # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            was_traced = np.zeros(len(bone), dtype=bool)
            for k in range(len(dist)):
                was_traced[k] = np.any(dist[k] <= cutoff)
            runlengths[i] += segment_lengths[was_traced].sum()
    return cutoffs, runlengths


# Feature columns 1..6 of a trace node, written as knossos node data.
_NODE_FEATURE_NAMES = ("axoness_proba0", "axoness_proba1", "axoness_proba2",
                       "spiness_proba0", "spiness_proba1", "spiness_proba2")


def _set_node_features(node, features, k):
    """Attach feature columns 1..6 of trace node ``k`` to knossos ``node``."""
    for col, name in enumerate(_NODE_FEATURE_NAMES, start=1):
        node.setDataElem(name, features[k, col])


def trace_to_anno(trace_xyz, name, anno=None, root=None):
    """
    Convert a trace into a chain of knossos nodes.

    :param name: used as the annotation comment (basename) if ``anno`` is None.
    :param anno: existing annotation to add the nodes to.
    :param root: node index to mark as root; defaults to ``trace_xyz.root``.
    :return: ``(anno, node_mapping)`` with node_mapping: trace index -> node.
    """
    if isinstance(trace_xyz, Trace):
        feature_avail = len(trace_xyz.features) == len(trace_xyz)
    else:
        feature_avail = True
    radius = 1.0  # used when no per-node features are available
    if anno is None:
        anno = knossos_skeleton.SkeletonAnnotation()
        anno.scaling = (9.0, 9.0, 20.0)
        anno.setComment(os.path.split(name)[1])

    node_mapping = dict()
    trace_coords = np.round(trace_xyz.coords).astype(np.int16)
    last_node = None
    for k, coord in enumerate(trace_coords):
        if feature_avail:
            radius = trace_xyz.features[k, 0]
        node = knossos_skeleton.SkeletonNode()
        node.from_scratch(anno, coord[0], coord[1], coord[2], radius=radius)
        if feature_avail:
            # Features of node k belong to node k (previously written onto
            # the preceding node).
            _set_node_features(node, trace_xyz.features, k)
        anno.addNode(node)
        node_mapping[k] = node
        if last_node is not None:
            last_node.addChild(node)
        last_node = node

    if root is None and isinstance(trace_xyz, Trace) and trace_xyz.root is not None:
        root = trace_xyz.root
    if root is not None:
        node_mapping[root].setRoot()
    return anno, node_mapping
def trace_to_kzip(trace_xyz, fname):
    """
    Write a single trace to ``<fname>.k.zip`` as a KNOSSOS skeleton.

    A leading ``~`` in the output path is expanded, as in
    ``trace_to_kzip_multi``.
    """
    skel_obj = knossos_skeleton.Skeleton()
    anno, _ = trace_to_anno(trace_xyz, fname)
    skel_obj.add_annotation(anno)
    outfile = fname + '.k.zip'
    skel_obj.to_kzip(os.path.expanduser(outfile))
def trace_to_kzip_multi(traces, fname):
    """
    Write several traces as separate annotations into ``<fname>.k.zip``.

    :param traces: list or dict of traces. Elements that are numpy arrays are
        treated as plain (n, 3) coordinate lists; anything else goes through
        ``trace_to_anno``.
    """
    if isinstance(traces, dict):
        traces = traces.values()
    skel_obj = knossos_skeleton.Skeleton()
    for i, trace_xyz in enumerate(traces):
        if isinstance(trace_xyz, np.ndarray):
            points = np.round(trace_xyz).astype(np.int16)
            anno = knossos_skeleton.SkeletonAnnotation()
            anno.scaling = (9.0, 9.0 ,20.0)
            anno.setComment(os.path.split(fname)[1] + "_%i" % i)
            prev = knossos_skeleton.SkeletonNode()
            prev.from_scratch(anno, points[0, 0], points[0, 1], points[0, 2])
            prev.setRoot()
            anno.addNode(prev)
            for x, y, z in points[1:]:
                node = knossos_skeleton.SkeletonNode()
                node.from_scratch(anno, x, y, z)
                anno.addNode(node)
                prev.addChild(node)
                prev = node
        else:
            anno, _ = trace_to_anno(trace_xyz, fname + "_%i" % i)
        skel_obj.add_annotation(anno)
    outfile = fname + '.k.zip'
    skel_obj.to_kzip(os.path.expanduser(outfile))


def bbox_cube_anno(off_xyz, sz_xyz, comment="?", cross_edges=False):
    """
    Knossos annotation outlining the axis-aligned box ``off_xyz + [0, sz_xyz]``.

    :param cross_edges: if true, additionally call ``addChild`` for every
        ordered pair of nodes in ``anno.nodes``.
    :return: the annotation.
    """
    off_xyz = np.array(off_xyz)
    sz_xyz = np.array(sz_xyz)
    # Corners of the unit cube; the index order defines the edge list below.
    unit_corners = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
                             [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]])
    cords = off_xyz + sz_xyz * unit_corners

    anno = knossos_skeleton.SkeletonAnnotation()
    anno.scaling = (9.0, 9.0 ,20.0)
    anno.setComment("%s: %s - %s"%(comment, off_xyz,sz_xyz))
    nodes = []
    for x, y, z in cords:
        node = knossos_skeleton.SkeletonNode()
        node.from_scratch(anno, x, y, z)
        anno.addNode(node)
        nodes.append(node)

    box_edges = [(0, 1), (0, 3), (0, 4), (1, 2), (1, 5), (3, 2), (3, 7),
                 (4, 5), (4, 7), (6, 7), (6, 5), (6, 2)]
    for parent, child in box_edges:
        nodes[parent].addChild(nodes[child])
    if cross_edges:
        for n1 in anno.nodes:
            for n2 in anno.nodes:
                n1.addChild(n2)
    return anno