Connection Weights in Nengo


For most models, we do not have to worry about synaptic connection weights in Nengo. They are derived automatically using the NEF formulas from our specification of the operation to be performed. However, sometimes we may want to work with these weights explicitly, for example to explore learning rules or to model the effects of randomly eliminating synaptic connections.

To use weights in Nengo, we need to create a non-decoded termination. Unlike the DecodedTermination that we normally use, it explicitly specifies the connection weight for each pair of neurons. Note that this leads to a slower simulation: for a normal DecodedTermination, Nengo takes advantage of the fact that the resulting synaptic connection weight matrix is always of low rank.
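
To illustrate why the decoded form is cheaper, the following standalone sketch computes the input to each target neuron in two ways: through the full weight matrix, and through the low-dimensional decoded value. It is plain Python and independent of Nengo; the helper names (matvec, matmul, transpose) and all numbers are made up for illustration, and it is not how Nengo itself is implemented. The two results agree, but the factored form needs fewer multiplications whenever the represented dimension is much smaller than the number of neurons.

```python
# Illustrative sketch only; not Nengo's implementation.

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

# Hypothetical values: 4 source neurons, 3 target neurons, 1 dimension.
decoders = [[0.5], [-0.25], [0.125], [0.25]]   # one row per source neuron
encoders = [[1.0], [-1.0], [1.0]]              # one row per target neuron
activity = [10.0, 0.0, 4.0, 2.0]               # source neuron activities

# Full weight matrix: 3 x 4, one weight per pair of neurons.
weights = matmul(encoders, transpose(decoders))
full = matvec(weights, activity)                 # 3 * 4 = 12 multiplications

# Factored form: decode to a single value, then encode it.
decoded = matvec(transpose(decoders), activity)  # 4 multiplications
factored = matvec(encoders, decoded)             # 3 multiplications

print(full)
print(factored)
```

With N source neurons, M target neurons and D dimensions, the full matrix costs N*M multiplications per step, while the factored form costs (N+M)*D.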

As per the Neural Engineering Framework, the synaptic connection weights can be found by multiplying the decoding vectors of the first ensemble by the encoding vectors of the second ensemble. The following function will calculate this and create the new connection.

from ca.nengo.util import MU

def make_weights(network, source, origin, target, termination, pstc):
    # Look up both ensembles by name within the network
    source = network.getNode(source)
    target = network.getNode(target)
    # Decoding vectors of the chosen origin on the source ensemble
    decoder = source.getOrigin(origin).decoders
    # Weights: target encoders times the transposed source decoders
    w = MU.prod(target.encoders, MU.transpose(decoder))
    # Non-decoded termination that takes the full weight matrix
    t = target.addTermination(termination, w, pstc, False)
    # Project the source's 'AXON' origin onto the new termination
    network.addProjection(source.getOrigin('AXON'), t)
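
In matrix form, the calculation performed by make_weights can be written as below. The symbols are chosen here for illustration and do not come from the Nengo API: E holds the encoding vectors of the target ensemble (row i is the encoder e_i of target neuron i), D holds the decoding vectors of the source origin (row j is the decoder d_j of source neuron j), and omega_ij is the weight from source neuron j to target neuron i. This mirrors the code above, which includes no separate gain term.

```latex
\omega = E\,D^{\mathsf{T}}, \qquad \omega_{ij} = \mathbf{e}_i \cdot \mathbf{d}_j
```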

We can use this function in Nengo by clicking on the network we are currently working with (NOTE: select the network, not either of the two ensembles) and then calling it from the script console, where ‘that’ refers to the selected object. For example, if we want to connect the neural ensemble ‘A’ to the neural ensemble ‘B’ using the ‘X’ origin (i.e. no non-linear calculations), then we could do the following:

make_weights(that, 'A', 'X', 'B', 'w', 0.01)

The new termination will be called ‘w’ and will use a post-synaptic time constant of 0.01 seconds.

If we want to perform a non-linear operation with this connection, we can define a decoded origin that computes the desired operation, and then replace ‘X’ in the above command with the name of that decoded origin.

We may also want to apply a linear transformation on this connection. With a decoded termination, this is done by specifying a transformation matrix. We can add this capability to our function as follows:

from ca.nengo.util import MU

def make_weights(network, source, origin, target, termination, pstc, transform=None):
    source = network.getNode(source)
    target = network.getNode(target)
    decoder = source.getOrigin(origin).decoders

    # Fold the optional linear transformation into the decoders
    if transform is not None:
        decoder = MU.prod(decoder, transform)

    # Weights: target encoders times the transposed (transformed) decoders
    w = MU.prod(target.encoders, MU.transpose(decoder))
    t = target.addTermination(termination, w, pstc, False)
    network.addProjection(source.getOrigin('AXON'), t)

We can now optionally specify a transformation matrix when we create the weights.

make_weights(that, 'A', 'X', 'B', 'w', 0.01, transform=[[1, 0.5, 2], [-0.3, 0, 0]])
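
The shape of the transformation matrix follows from the matrix products in the function: it needs one row per dimension of the source origin and one column per dimension of the target ensemble. The sketch below checks this in plain Python for the transform used above, assuming (hypothetically) that ‘A’ represents 2 dimensions and ‘B’ represents 3. The decoder and encoder values, the neuron counts, and the helper names are all made up for illustration.

```python
# Illustrative shape check only; values and helper names are hypothetical.

def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def shape(m):
    return (len(m), len(m[0]))

# Transform from the example call: 2 rows (source dims) x 3 columns (target dims)
transform = [[1, 0.5, 2], [-0.3, 0, 0]]

# 4 source neurons, each with a 2-dimensional decoding vector
decoders = [[0.1, 0.0], [0.0, 0.2], [0.3, 0.1], [0.2, 0.2]]

# 5 target neurons, each with a 3-dimensional encoding vector
encoders = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],
            [0.6, 0.0, 0.8], [0.0, -1.0, 0.0]]

transformed = matmul(decoders, transform)            # (4 x 2)(2 x 3)
weights = matmul(encoders, transpose(transformed))   # (5 x 3)(3 x 4)

print(shape(transformed))
print(shape(weights))
```

The final weight matrix has one row per target neuron and one column per source neuron, regardless of the dimensions represented in between.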

We can also modify the connection weights in any way that we desire before creating the termination. For example, we may wish to examine what happens when connections are destroyed by setting some of the values in the matrix to 0. The following function helps you to do this by randomly setting a given proportion of the values in a matrix to 0.

import random

def destroy_connections(w, proportion):
    # Set a random proportion of the entries of w to 0, in place
    if proportion > 1: proportion = 1
    rows = len(w)
    cols = len(w[0])
    total = rows * cols
    # pool holds the flat indices of entries that are still intact
    pool = list(xrange(total))
    for i in xrange(int(proportion * total)):
        # Pick one of the remaining slots and zero the entry it names
        x = random.randrange(total - i)
        index = pool[x]
        w[index // cols][index % cols] = 0
        # Move the last remaining index into the slot just used
        pool[x] = pool[total - i - 1]
    return w
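
As a quick check of the idea outside Nengo, here is a minimal standalone sketch that zeroes a fixed proportion of entries by using random.sample to choose distinct positions. It is an alternative to the pool-based loop above rather than the same code; the function name zero_fraction and the test matrix are hypothetical.

```python
import random

def zero_fraction(w, proportion):
    """Set a random proportion of the entries of w (a list of rows) to 0, in place."""
    proportion = min(max(proportion, 0.0), 1.0)
    rows, cols = len(w), len(w[0])
    total = rows * cols
    # random.sample draws distinct flat indices, so no entry is picked twice
    for index in random.sample(range(total), int(proportion * total)):
        w[index // cols][index % cols] = 0
    return w

w = [[1.0] * 8 for _ in range(5)]        # 5 x 8 matrix with no zeros
zero_fraction(w, 0.25)
zeros = sum(row.count(0) for row in w)   # number of destroyed entries
print(zeros)                             # 10 of the 40 entries
```

Because the positions are drawn without replacement, exactly int(proportion * total) entries are zeroed, which makes the destroyed fraction easy to verify.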

We can modify our make_weights function to call destroy_connections as follows:

from ca.nengo.util import MU

def make_weights(network, source, origin, target, termination, pstc, transform=None, destroy=0):
    source = network.getNode(source)
    target = network.getNode(target)
    decoder = source.getOrigin(origin).decoders

    # Fold the optional linear transformation into the decoders
    if transform is not None:
        decoder = MU.prod(decoder, transform)

    w = MU.prod(target.encoders, MU.transpose(decoder))

    # Optionally zero a random proportion of the weights (modifies w in place)
    if destroy > 0:
        destroy_connections(w, destroy)

    t = target.addTermination(termination, w, pstc, False)
    network.addProjection(source.getOrigin('AXON'), t)

We can now create the connection and destroy 25% of its weights as follows:

make_weights(that, 'A', 'X', 'B', 'w', 0.01, destroy=0.25)

You can download the code for this example here. Put it into your Nengo directory, open the script console, and type “run weights.py”.