[Numpy-svn] r3086 - in trunk/numpy: core lib lib/tests oldnumeric

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Aug 28 15:46:13 EDT 2006


Author: oliphant
Date: 2006-08-28 14:46:08 -0500 (Mon, 28 Aug 2006)
New Revision: 3086

Modified:
   trunk/numpy/core/numeric.py
   trunk/numpy/lib/function_base.py
   trunk/numpy/lib/tests/test_function_base.py
   trunk/numpy/oldnumeric/misc.py
Log:
Merge changes mistakenly added to 1.0b4 tag to the main trunk

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-08-28 19:42:27 UTC (rev 3085)
+++ trunk/numpy/core/numeric.py	2006-08-28 19:46:08 UTC (rev 3086)
@@ -7,7 +7,7 @@
            'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray',
            'isfortran', 'empty_like', 'zeros_like',
            'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot',
-           'alterdot', 'restoredot', 'cross',
+           'alterdot', 'restoredot', 'cross', 'tensordot',
            'array2string', 'get_printoptions', 'set_printoptions',
            'array_repr', 'array_str', 'set_string_function',
            'little_endian', 'require',
@@ -252,7 +252,70 @@
     def restoredot():
         pass
 
 
+def tensordot(a, b, axes=(-1,0)):
+    r"""Return the tensor dot product of two (ndim >= 1) arrays.
+
+    r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy}
+
+    where the axes to be summed over are given by the axes argument:
+    axes[0] is the axis (or sequence of axes) of a to sum over and
+    axes[1] is the matching axis (or sequence of axes) of b.  When
+    several axes are summed over, both sequences must have the same
+    length; they are paired element by element.  Negative axes count
+    from the end.  The result has the remaining axes of a followed by
+    the remaining axes of b, each in their original order.
+
+    Raises ValueError if the paired axes differ in number or length.
+    """
+    axes_a, axes_b = axes
+    # A bare integer selects a single axis.
+    try:
+        axes_a = list(axes_a)
+    except TypeError:
+        axes_a = [axes_a]
+    try:
+        axes_b = list(axes_b)
+    except TypeError:
+        axes_b = [axes_b]
+
+    a, b = asarray(a), asarray(b)
+    shape_a = a.shape
+    shape_b = b.shape
+    if len(axes_a) != len(axes_b):
+        raise ValueError("shape-mismatch for sum")
+    for k in range(len(axes_a)):
+        # Indexing the shapes also rejects out-of-range axes (IndexError).
+        if shape_a[axes_a[k]] != shape_b[axes_b[k]]:
+            raise ValueError("shape-mismatch for sum")
+        # Normalize negative axes so the "not in" tests below work.
+        if axes_a[k] < 0:
+            axes_a[k] = axes_a[k] + len(shape_a)
+        if axes_b[k] < 0:
+            axes_b[k] = axes_b[k] + len(shape_b)
+
+    # Move the summed axes to the end of a and the front of b so the
+    # whole contraction becomes a single 2-D matrix product.
+    keep_a = [k for k in range(len(shape_a)) if k not in axes_a]
+    keep_b = [k for k in range(len(shape_b)) if k not in axes_b]
+    olda = [shape_a[k] for k in keep_a]
+    oldb = [shape_b[k] for k in keep_b]
+    # Explicit sizes (rather than -1) keep zero-length axes working.
+    rows = 1
+    for n in olda:
+        rows = rows * n
+    inner = 1
+    for k in axes_a:
+        inner = inner * shape_a[k]
+    cols = 1
+    for n in oldb:
+        cols = cols * n
+
+    at = a.transpose(keep_a + axes_a).reshape((rows, inner))
+    bt = b.transpose(axes_b + keep_b).reshape((inner, cols))
+    return dot(at, bt).reshape(olda + oldb)
+
+
 def _move_axis_to_0(a, axis):
     if axis == 0:
         return a

Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2006-08-28 19:42:27 UTC (rev 3085)
+++ trunk/numpy/lib/function_base.py	2006-08-28 19:46:08 UTC (rev 3086)
@@ -1,13 +1,13 @@
 __all__ = ['logspace', 'linspace',
            'select', 'piecewise', 'trim_zeros',
            'copy', 'iterable', #'base_repr', 'binary_repr',
            'diff', 'gradient', 'angle', 'unwrap', 'sort_complex', 'disp',
-           'unique', 'extract', 'insert', 'nansum', 'nanmax', 'nanargmax',
+           'unique', 'extract', 'place', 'nansum', 'nanmax', 'nanargmax',
            'nanargmin', 'nanmin', 'vectorize', 'asarray_chkfinite', 'average',
            'histogram', 'bincount', 'digitize', 'cov', 'corrcoef', 'msort',
            'median', 'sinc', 'hamming', 'hanning', 'bartlett', 'blackman',
            'kaiser', 'trapz', 'i0', 'add_newdoc', 'add_docstring', 'meshgrid',
-           'deletefrom', 'insertinto', 'appendonto'
+           'delete', 'insert', 'append'
            ]
 
 import types
@@ -545,7 +545,7 @@
     """
     return _nx.take(ravel(arr), nonzero(ravel(condition))[0])
 
-def insert(arr, mask, vals):
+def place(arr, mask, vals):
     """Similar to putmask arr[mask] = vals but the 1D array vals has the
     same number of elements as the non-zero values of mask. Inverse of
     extract.
@@ -1011,7 +1011,7 @@
     Y = y.repeat(numCols, axis=1)
     return X, Y
 
-def deletefrom(arr, obj, axis=None):
+def delete(arr, obj, axis=None):
     """Return a new array with sub-arrays along an axis deleted.
 
     Return a new array with the sub-arrays (i.e. rows or columns)
@@ -1117,7 +1117,7 @@
     else:
         return new
 
-def insertinto(arr, obj, values, axis=None):
+def insert(arr, obj, values, axis=None):
     """Return a new array with values inserted along the given axis
     before the given indices
 
@@ -1205,7 +1205,7 @@
         return wrap(new)
     return new
 
-def appendonto(arr, values, axis=None):
+def append(arr, values, axis=None):
     """Append to the end of an array along axis (ravel first if None)
     """
     arr = asanyarray(arr)
@@ -1215,3 +1215,4 @@
         values = ravel(values)
         axis = arr.ndim-1
     return concatenate((arr, values), axis=axis)
+

Modified: trunk/numpy/lib/tests/test_function_base.py
===================================================================
--- trunk/numpy/lib/tests/test_function_base.py	2006-08-28 19:42:27 UTC (rev 3085)
+++ trunk/numpy/lib/tests/test_function_base.py	2006-08-28 19:46:08 UTC (rev 3086)
@@ -237,17 +237,17 @@
         a = array([1,3,2,1,2,3,3])
         b = extract(a>1,a)
         assert_array_equal(b,[3,2,2,3,3])
-    def check_insert(self):
+    def check_place(self):
         a = array([1,4,3,2,5,8,7])
-        insert(a,[0,1,0,1,0,1,0],[2,4,6])
+        place(a,[0,1,0,1,0,1,0],[2,4,6])
         assert_array_equal(a,[1,2,3,4,5,6,7])
     def check_both(self):
         a = rand(10)
         mask = a > 0.5
         ac = a.copy()
         c = extract(mask, a)
-        insert(a,mask,0)
-        insert(a,mask,c)
+        place(a,mask,0)
+        place(a,mask,c)
         assert_array_equal(a,ac)
 
 class test_vectorize(NumpyTestCase):

Modified: trunk/numpy/oldnumeric/misc.py
===================================================================
--- trunk/numpy/oldnumeric/misc.py	2006-08-28 19:42:27 UTC (rev 3085)
+++ trunk/numpy/oldnumeric/misc.py	2006-08-28 19:46:08 UTC (rev 3086)
@@ -9,7 +9,7 @@
            'searchsorted', 'put', 'fromfunction', 'copy', 'resize',
            'array_repr', 'e', 'StringIO', 'pickle',
            'argsort', 'convolve', 'loads', 'cross_correlate',
-           'Pickler', 'dot', 'outerproduct', 'innerproduct']
+           'Pickler', 'dot', 'outerproduct', 'innerproduct', 'insert']
 
 import types
 import StringIO
@@ -23,7 +23,8 @@
      choose, swapaxes, array_str, array_repr, e, pi, \
      fromfunction, resize, around, concatenate, vdot, transpose, \
      diagonal, searchsorted, put, argsort, convolve, dot, \
-     outer as outerproduct, inner as innerproduct, correlate as cross_correlate
+     outer as outerproduct, inner as innerproduct, correlate as cross_correlate, \
+     place as insert
 
 from array_printer import array2string
 




More information about the Numpy-svn mailing list