[Numpy-svn] r3085 - in tags/1.0b4/numpy: core lib

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Aug 28 15:42:31 EDT 2006


Author: oliphant
Date: 2006-08-28 14:42:27 -0500 (Mon, 28 Aug 2006)
New Revision: 3085

Modified:
   tags/1.0b4/numpy/core/numeric.py
   tags/1.0b4/numpy/lib/function_base.py
Log:
Update tensordot.

Modified: tags/1.0b4/numpy/core/numeric.py
===================================================================
--- tags/1.0b4/numpy/core/numeric.py	2006-08-28 05:56:49 UTC (rev 3084)
+++ tags/1.0b4/numpy/core/numeric.py	2006-08-28 19:42:27 UTC (rev 3085)
@@ -7,7 +7,7 @@
            'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray',
            'isfortran', 'empty_like', 'zeros_like',
            'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot',
-           'alterdot', 'restoredot', 'cross',
+           'alterdot', 'restoredot', 'cross', 'tensordot',
            'array2string', 'get_printoptions', 'set_printoptions',
            'array_repr', 'array_str', 'set_string_function',
            'little_endian', 'require',
@@ -252,7 +252,78 @@
     def restoredot():
         pass
 
+def tensordot(a, b, axes=(-1,0)):
+    r"""Return the tensor dot product of a and b (arrays with ndim >= 1).
 
+    r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy}
+
+    where the axes summed over are given by the axes argument: axes[0]
+    gives the axis (or sequence of axes) of a to sum over and axes[1]
+    gives the matching axis (or sequence of axes) of b.  When several
+    axes are summed over, both elements must be sequences of the same
+    length, paired position by position.
+
+    The result has the unsummed dimensions of a, in order, followed by
+    the unsummed dimensions of b, in order.
+
+    Raises ValueError ("shape-mismatch for sum") if the number of axes
+    differs or paired axes have different lengths.
+    """
+    axes_a, axes_b = axes
+    # Each element of axes may be a single axis or a sequence of axes.
+    try:
+        na = len(axes_a)
+        axes_a = list(axes_a)
+    except TypeError:
+        axes_a = [axes_a]
+        na = 1
+    try:
+        nb = len(axes_b)
+        axes_b = list(axes_b)
+    except TypeError:
+        axes_b = [axes_b]
+        nb = 1
+
+    a, b = asarray(a), asarray(b)
+    # 'as' is a keyword in newer Pythons, hence as_.
+    as_ = a.shape
+    bs = b.shape
+    nda = len(as_)
+    ndb = len(bs)
+    equal = (na == nb)
+    if equal:
+        for k in range(na):
+            if as_[axes_a[k]] != bs[axes_b[k]]:
+                equal = False
+                break
+            # Normalize negative axes so the "not in" tests below work.
+            if axes_a[k] < 0:
+                axes_a[k] += nda
+            if axes_b[k] < 0:
+                axes_b[k] += ndb
+    if not equal:
+        raise ValueError("shape-mismatch for sum")
+
+    def _prod(dims):
+        p = 1
+        for d in dims:
+            p *= d
+        return p
+
+    # Move the summed axes to the end of a and the front of b so the
+    # whole contraction becomes one 2-d matrix product.
+    keep_a = [k for k in range(nda) if k not in axes_a]
+    keep_b = [k for k in range(ndb) if k not in axes_b]
+    olda = [as_[k] for k in keep_a]
+    oldb = [bs[k] for k in keep_b]
+    nsum = _prod([as_[k] for k in axes_a])
+
+    at = a.transpose(keep_a + axes_a).reshape(_prod(olda), nsum)
+    bt = b.transpose(axes_b + keep_b).reshape(nsum, _prod(oldb))
+    res = dot(at, bt)
+    return res.reshape(tuple(olda + oldb))
+
+
 def _move_axis_to_0(a, axis):
     if axis == 0:
         return a

Modified: tags/1.0b4/numpy/lib/function_base.py
===================================================================
--- tags/1.0b4/numpy/lib/function_base.py	2006-08-28 05:56:49 UTC (rev 3084)
+++ tags/1.0b4/numpy/lib/function_base.py	2006-08-28 19:42:27 UTC (rev 3085)
@@ -7,7 +7,7 @@
            'histogram', 'bincount', 'digitize', 'cov', 'corrcoef', 'msort',
            'median', 'sinc', 'hamming', 'hanning', 'bartlett', 'blackman',
            'kaiser', 'trapz', 'i0', 'add_newdoc', 'add_docstring', 'meshgrid',
-           'delete', 'insert', 'append', 'tensordot'
+           'delete', 'insert', 'append'
            ]
 
 import types
@@ -1216,11 +1216,3 @@
         axis = arr.ndim-1
     return concatenate((arr, values), axis=axis)
 
-def tensordot(arr1, arr2, axes1=-1, axes2=0):
-    """tensordot returns the product for any (ndim >= 1) arrays.
-
-    r_{xxx, yyy} = \sum_k arr1_{xxx,k} arr2_{k,yyy} where
-    the axes of k 
-    """
-    #FIXME
-    pass




More information about the Numpy-svn mailing list