Source code for elektronn2.utils.d3viz.formatting

"""
Visualisiation code taken form Theano
Original Author: Christof Angermueller <cangermueller@gmail.com>
Adapted with permission for the ELEKTRONN2 Toolkit by Marius Killinger 2016
Note that this code is licensed under the original terms of Theano (see license
containing directory).
"""

from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, \
    super, zip

import future.utils
import re
import os
import shutil
import logging
from collections import OrderedDict

from ...neuromancer.graphutils import TaggedShape


logger = logging.getLogger('elektronn2log')


__path__ = os.path.dirname(os.path.realpath(__file__))

pydot_imported = False
try:
    # pydot2 supports py3
    import pydotplus as pd


    if pd.find_graphviz():
        pydot_imported = True
except ImportError:
    try:
        # fall back on pydot if necessary
        import pydot as pd


        if pd.find_graphviz():
            pydot_imported = True
    except ImportError:
        pass  # tests should not fail on optional dependency

if not pydot_imported:
    logger.warning('Failed to import pydot/pydotplus. You must install '
                   'graphviz and either pydot or pydotplus for '
                   '`PyDotFormatter` to work.')


[docs]def sort(model, select_outputs): graph_sorted = OrderedDict() graph_unsorted = OrderedDict() for n in model.nodes.values(): if n.is_source: graph_sorted[n] = True else: graph_unsorted[n] = True for n in graph_unsorted: if not len(n.children) and select_outputs: # output node ### TODO raise NotImplementedError( "Select outputs does not work. Make sure " "that outputs which are needed from scan " "are not removed, that edges to removed " "nodes are not drawn and and that after " "removing unwanted output nodes in one sort " "there will be new unwanted terminal nodes") if n.name in select_outputs: graph_sorted[n] = True else: graph_sorted[n] = True return graph_sorted
[docs]class PyDotFormatter2(object): """Create `pydot` graph object from Theano function. Parameters ---------- compact : bool if True, will remove intermediate variables without name. Attributes ---------- node_colors : dict Color table of node types. apply_colors : dict Color table of apply nodes. shapes : dict Shape table of node types. """ def __init__(self, compact=True): """Construct PyDotFormatter object.""" if not pydot_imported: raise ImportError( 'Failed to import pydot/pydotplus. You must install ' 'graphviz and either pydot or pydotplus for ' '`PyDotFormatter` to work.') self.compact = compact self.__node_prefix = 'n' def __add_node(self, node): """Add new node to node list and return unique id. Parameters ---------- node : Theano graph node Apply node, tensor variable, or shared variable in compute graph. Returns ------- str Unique node id. """ assert node not in self.__nodes _id = '%s%d' % (self.__node_prefix, len(self.__nodes) + 1) self.__nodes[node] = _id return _id def __node_id(self, node): """Return unique node id. Parameters ---------- node : Theano graph node Apply node, tensor variable, or shared variable in compute graph. Returns ------- str Unique node id. """ if node in self.__nodes: return self.__nodes[node] else: return self.__add_node(node)
[docs] def get_node_props(self, node): cls_name = node.__class__.__name__ if cls_name.startswith('_'): cls_name = cls_name[1:] __node_id = self.__node_id(node) node_type = 0 # normal if node.is_source: node_type = 1 elif len(node.children)==0: node_type = 2 nparams = {} nparams['name'] = __node_id nparams['label'] = "%s - %s" % (cls_name, node.name) nparams['profile'] = [0, 1e-5] nparams['style'] = 'filled' nparams['type'] = 'colored' nparams['shape'] = 'ellipse' nparams['fillcolor'] = '#008000' # 'green' if cls_name=='Conv': nparams['shape'] = 'invtrapezium' if cls_name=='UpConv': nparams['shape'] = 'trapezium' if cls_name=='Perceptron': nparams['shape'] = 'octagon' if cls_name in ['GRU', 'LSTM']: nparams['shape'] = 'doubleoctagon' if cls_name=='Concat': nparams['shape'] = 'house' if cls_name=='ScanN': nparams['shape'] = 'doublecircle' nparams['fillcolor'] = 'red' if 'loss' in cls_name.lower() or 'nll' in cls_name.lower(): nparams['shape'] = 'diamond' nparams['fillcolor'] = '#FFAA22' if cls_name=='ValueNode': nparams['shape'] = 'box' if node_type==2: nparams['fillcolor'] = 'blue' nparams['shape'] = 'box' elif node_type==1: nparams['fillcolor'] = 'yellow' nparams['shape'] = 'box' if isinstance(node.shape, TaggedShape) and node.shape.ndim: nparams['dtype'] = node.shape # .ext_repr elif node.shape: nparams['dtype'] = node.shape nparams['tag'] = None # 'tag' # Noone Not needed? nparams['node_type'] = 'node type' # not needed? nparams['apply_op'] = 'apply_op' # not needed? return nparams
def __call__(self, model, select_outputs=None): """Create pydot graph from function. Parameters ---------- model: model object Returns ------- pydot.Dot Pydot graph of `fct` """ graph = pd.Dot() self.__nodes = {} if select_outputs is not None and isinstance(select_outputs, str): select_outputs = [select_outputs, ] nodes = sort(model, select_outputs) # Create nodes for node in nodes: nparams = self.get_node_props(node) pd_node = dict_to_pdnode(nparams) graph.add_node(pd_node) # Create edges for node in nodes: for i, c in enumerate(node.children.values()): if c.__class__.__name__=="ScanN": if node in c.in_memory: # print("Skippnig",node,'for',c) continue p_id = self.__node_id(node) c_id = self.__node_id(c) edge_params = {} edge_params['color'] = 'black' edge_label = " " # str(i) pdedge = pd.Edge(p_id, c_id, label=edge_label, **edge_params) graph.add_edge(pdedge) if node.__class__.__name__=='ScanN': self.add_scan_edges(node, graph, nodes) return graph
[docs] def add_scan_edges(self, scan, graph, nodes): n = str(scan.n_steps) if scan.n_steps else "variable" # if scan.out_memory: # out = scan.out_memory # else: # out = scan.step_result out = [] for i in scan.out_memory_sl: name = scan.output_names[i] for nd in nodes: if nd.name==name: out.append(nd) break else: out.append("NotFound") # out = [nodes[name] for name in scan.output_name] for p, c in zip(out, scan.in_memory): for ci in c.children.values(): if ci.__class__.__name__=="ScanN": continue # if c in c.in_memory: p_id = self.__node_id(p) c_id = self.__node_id(ci) edge_params = {} edge_params['color'] = 'red' edge_params['constraint'] = False edge_params['penwidth'] = 3 edge_label = n + "x recur.\nreplace %s" % c.name pdedge = pd.Edge(p_id, c_id, label=edge_label, **edge_params) graph.add_edge(pdedge) if scan.in_iterate: for p, c in zip(scan.in_iterate, scan.in_iterate_0): p_id = self.__node_id(p) c_id = self.__node_id(c) edge_params = {} edge_params['color'] = 'red' edge_params['constraint'] = False edge_params['penwidth'] = 3 edge_label = n + "x recur.\niteration" pdedge = pd.Edge(p_id, c_id, label=edge_label, **edge_params) graph.add_edge(pdedge)
[docs]def dict_to_pdnode(d): """Create pydot node from dict.""" e = dict() for k, v in d.items(): if v is not None: if isinstance(v, list): v = '\t'.join([str(x) for x in v]) else: v = str(v) v = str(v) v = v.replace('"', '\'') e[k] = v pynode = pd.Node(**e) return pynode
[docs]def replace_patterns(x, replace): """Replace `replace` in string `x`. Parameters ---------- s : str String on which function is applied replace : dict `key`, `value` pairs where key is a regular expression and `value` a string by which `key` is replaced """ for from_, to in replace.items(): x = x.replace(str(from_), str(to)) return x
[docs]def escape_quotes(s): """Escape quotes in string. Parameters ---------- s : str String on which function is applied """ s = re.sub(r'''(['"])''', r'\\\1', s) return s
[docs]def visualise_model(model, outfile, copy_deps=True, select_outputs=None, image_format='png', *args, **kwargs): """ Parameters ---------- model : model object outfile : str Path to output HTML file. copy_deps : bool, optional Copy javascript and CSS dependencies to output directory. Notes ----- This function accepts extra parameters which will be forwarded to :class:`theano.d3viz.formatting.PyDotFormatter`. """ outfile = os.path.expanduser(outfile) # Create DOT graph formatter = PyDotFormatter2(*args, **kwargs) graph = formatter(model, select_outputs=select_outputs) graph.write(outfile + '.' + image_format, prog='dot', format=image_format) dot_graph_raw = graph.create_dot() if not future.utils.PY2: dot_graph_raw = dot_graph_raw.decode('utf8') dot_graph = escape_quotes(dot_graph_raw).replace('\n', '').replace('\r', '') # Create output directory if not existing outdir = os.path.dirname(outfile) if not outdir=='' and not os.path.exists(outdir): os.makedirs(outdir) # Read template HTML file template_file = os.path.join(__path__, 'html', 'template.html') with open(template_file) as f: template = f.read() # Copy dependencies to output directory src_deps = __path__ if copy_deps: dst_deps = 'd3viz' for d in ['js', 'css']: dep = os.path.join(outdir, dst_deps, d) if not os.path.exists(dep): shutil.copytree(os.path.join(src_deps, d), dep) else: dst_deps = src_deps # Replace patterns in template replace = {'%% JS_DIR %%': os.path.join(dst_deps, 'js'), '%% CSS_DIR %%': os.path.join(dst_deps, 'css'), '%% DOT_GRAPH %%': dot_graph,} html = replace_patterns(template, replace) # Write HTML file with open(outfile + '.html', 'w') as f: f.write(html) graph.write(outfile + '.' + image_format, prog='dot', format=image_format)