[Scipy-svn] r2167 - trunk/Lib/sandbox/ann

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Aug 18 12:41:56 EDT 2006


Author: fred.mailhot
Date: 2006-08-18 11:41:53 -0500 (Fri, 18 Aug 2006)
New Revision: 2167

Modified:
   trunk/Lib/sandbox/ann/srn.py
Log:
Refactoring SRN based on MLP. Simple 1-step backprop only.


Modified: trunk/Lib/sandbox/ann/srn.py
===================================================================
--- trunk/Lib/sandbox/ann/srn.py	2006-08-18 16:40:55 UTC (rev 2166)
+++ trunk/Lib/sandbox/ann/srn.py	2006-08-18 16:41:53 UTC (rev 2167)
@@ -8,19 +8,13 @@
 class srn:
     """Class to define, train and test a simple recurrent network
     a.k.a. 'Elman net' (cf. Elman 1991's Machine Learnig paper,inter alia)
-
-    ************************* NOTA BENE 2006-06-23 ************************
-    * This is obviously still very incomplete. The initial implementation *
-    * will only have straightforward backprop-through-time (with the      *
-    * option to truncate).                                                *
-    ***********************************************************************
     """
 
     _type = 'srn'
     _outfxns = ('linear','logistic','softmax')
-    _algs = ('bptt')         # hopefully eventually RTRL and EKF
+    _alg = ('srn')
 
-    def __init__(self,ni,nh,no,f,tau=-1,w=None):
+    def __init__(self,ni,nh,no,f,h=-1,w=None):
         """ Set up instance of srn. Initial weights are drawn from a 
         zero-mean Gaussian w/ variance is scaled by fan-in.
         (see Bishop 1995 for justification)
@@ -31,7 +25,6 @@
             f   - string description of output unit activation fxn;
                     one of {'linear','logistic','softmax'}
                     (n.b. hidden/context units use tanh)
-            h   - truncation constant for bptt(h)
             w   - initialized 1-d weight vector
         """
         if f not in self._outfxns:
@@ -39,14 +32,15 @@
             self.outfxn = 'linear'
         else:
             self.outfxn = f
+        # set up layers of units
         self.ni = ni
         self.nh = nh
-        self.nc = nh    # context units
+        self.nc = nh
         self.no = no
-        self.alg = self._algs[1]
         self.z = zeros((h,nh),dtype=Float)        # hidden activations for 1 epoch
         self.c = zeros((h,nh),dtype=Float)        # context activations for 1 epoch
         self.o = zeros((h,no),dtype=Float)       # output activiation for 1 epoch
+        self.p = zeros((nh,nw,nw),dtype=Float)
         if w:
             self.nw = size(w)
             self.wp = w
@@ -95,10 +89,11 @@
                                 self.b2.reshape(size(self.b2))])
 
     def fwd(self,x,w=None,hid=False):
-        """ Propagate values forward through the net. This (i) feeds the current input
-        and values of the context units (i.e. hidden vals from previous time step)
-        into the hidden layer, which is then (ii) fed to the output layer, and 
-        (iii) copied to the context layer
+        """ Propagate values forward through the net. 
+        This involves the following steps:
+        (i) feeds the current input and context values to the hidden layer, 
+        (ii) hidden layer net input is transformed and then sent to the outputs
+        (iii) output values are copied to the context layer
         Inputs:
             x   - matrix of all input patterns
             w   - 1-d vector of weights
@@ -111,14 +106,11 @@
         if wts is not None:
             self.wp = w
         self.unpack()
-        
-        # compute hidden activations
+        # compute net input to hiddens and then squash it
         self.z = tanh(dot(x,self.w1) + dot(self.c,self.wc) + dot(ones((len(x),1)),self.b1))
-        # copy hidden vals to context units
+        # send hidden vals to output and copy to context
+        o = dot(self.z,self.w2) + dot(ones((len(self.z),1)),self.b2)
         self.c = copy.copy(self.z)
-        # compute net input to output units
-        o = dot(self.z,self.w2) + dot(ones((len(self.z),1)),self.b2)
-        
         # compute output activations
         if self.outfxn == 'linear':
             y = o
@@ -134,7 +126,7 @@
             return array(y)
 
     def train(self,x,t,N):
-        """ The calls to the various trainig algorithms.
+        """ Train net by standard backpropagation
         Inputs:
             x   - all input patterns
             t   - all target patterns
@@ -142,30 +134,47 @@
         Outputs:
             w   - new weight vector
         """
-        pass
+        for i in range(N):
+            
 
     def errfxn(self,w,x,t):
         """ Error functions for each of the output-unit activation functions.
         Inputs:
             w   - current weight vector
-            x   - current pattern input(s) (len(x) == tau)
+            x   - current pattern input(s) (len(x) == self.h)
             t   - current pattern target(s)
         """
-        pass
+        y,z = self.fwd(w,x,True)  # NOTE(review): fwd is declared fwd(self,x,w=None,hid=False), so w and x look swapped here -- confirm
+        if self.outfxn == 'linear':
+            # calculate & return sum-of-squares error (SSE)
+            err = 0.5*sum(sum(array(y-t)**2,axis=1))
+        elif self.outfxn == 'logistic':
+            # calculate & return binary cross-entropy (base-2 logs)
+            err = -1.0*sum(sum(t*log2(y)+(1-t)*log2(1-y),axis=1))
+        elif self.outfxn == 'softmax':
+            # calculate & return multi-class cross-entropy (base-2 logs)
+            err = -1.0*sum(sum(t*log2(y),axis=1))
+        else:
+            # only reachable if outfxn was changed after __init__ (which falls back to 'linear'); use SSE
+            err = 0.5*sum(sum(array(y-t)**2,axis=1))
+        
+        # returns (error, outputs y, z); z is presumably the hidden activations from fwd(...,hid=True) -- confirm
+        return err,y,z
 
 def main():
     """ Set up a 1-2-1 SRN to solve the temporal-XOR problem from Elman 1990.
     """
     from scipy.io import read_array, write_array
-    print "Creating 1-2-1 SRN for 'temporal-XOR' (sent net.trunc to 2)"
+    print "Creating 1-2-1 SRN for 'temporal-XOR' (net.h = 2)"
     net = srn(1,2,1,'logistic',2)
+    print net
     print "\nLoading training and test sets...",
     trn_input = read_array('data/t-xor1.dat')
     trn_targs = hstack([trn_input[1:],trn_input[0]])
     tst_input = read_array('data/t-xor2.dat')
     tst_targs = hstack([tst_input[1:],tst_input[0]])
     print "done."
-    N = input("Number of times to see all patterns: ")
+    N = input("Number of iterations over training set: ")
     
     print "\nInitial error: ",net.errfxn(net.wp,tst_input,tst_targs)
     net.train(trn_input,trn_targs,N)




More information about the Scipy-svn mailing list