[Numpy-svn] r3150 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Wed Sep 13 22:34:05 EDT 2006


Author: oliphant
Date: 2006-09-13 21:33:55 -0500 (Wed, 13 Sep 2006)
New Revision: 3150

Modified:
   trunk/numpy/lib/function_base.py
   trunk/numpy/lib/index_tricks.py
   trunk/numpy/lib/tests/test_index_tricks.py
Log:
Fix up r_ so that you can specify the minimum number of dimensions to force arrays to, alter the concatenation axis, and choose whether or not to transpose 1-d arrays.

Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py	2006-09-14 01:49:10 UTC (rev 3149)
+++ trunk/numpy/lib/function_base.py	2006-09-14 02:33:55 UTC (rev 3150)
@@ -103,7 +103,7 @@
     else:
         return n, bins
 
-def  histogramnd(sample, bins=10, range=None, normed=False):
+def histogramnd(sample, bins=10, range=None, normed=False):
     """histogramnd(sample, bins = 10, range = None, normed = False) -> H, edges
     
     Return the N-dimensional histogram computed from sample.

Modified: trunk/numpy/lib/index_tricks.py
===================================================================
--- trunk/numpy/lib/index_tricks.py	2006-09-14 01:49:10 UTC (rev 3149)
+++ trunk/numpy/lib/index_tricks.py	2006-09-14 02:33:55 UTC (rev 3150)
@@ -10,7 +10,7 @@
 import sys
 import types
 import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType
+from numpy.core.numeric import asarray, ScalarType, array
 
 import function_base
 import numpy.core.defmatrix as matrix
@@ -207,6 +207,8 @@
         self.col = 0
 
     def __getitem__(self,key):
+        trans1d = False
+        ndmin = 1
         if isinstance(key, str):
             frame = sys._getframe().f_back
             mymat = matrix.bmat(key,frame.f_globals,frame.f_locals)
@@ -230,22 +232,47 @@
                     newobj = function_base.linspace(start, stop, num=size)
                 else:
                     newobj = _nx.arange(start, stop, step)
-            elif type(key[k]) is str:
-                if (key[k] in 'rc'):
+                if ndmin > 1:
+                    newobj = array(newobj,copy=False,ndmin=ndmin)
+                    if trans1d:
+                        newobj = newobj.T
+            elif isinstance(key[k],str):
+                if k != 0:
+                    raise ValueError, "special directives must be the"\
+                          "first entry."
+                key0 = key[0]
+                if key0 in 'rc':
                     self.matrix = True
-                    self.col = (key[k] == 'c')
+                    self.col = (key0 == 'c')
                     continue
+                if ',' in key0:
+                    vec = key0.split(',')
+                    try:
+                        self.axis, ndmin = \
+                                   [int(x) for x in vec[:2]]
+                        if len(vec) == 3 and vec[2] == 't':
+                            trans1d = True
+                        continue
+                    except:
+                        raise ValueError, "unknown special directive"
                 try:
                     self.axis = int(key[k])
                     continue
                 except (ValueError, TypeError):
                     raise ValueError, "unknown special directive"
             elif type(key[k]) in ScalarType:
-                newobj = asarray([key[k]])
+                newobj = array(key[k],ndmin=ndmin)
                 scalars.append(k)
                 scalar = True
             else:
                 newobj = key[k]
+                if ndmin > 1:
+                    tempobj = array(newobj, copy=False, subok=True)
+                    newobj = array(newobj, copy=False, subok=True,
+                                   ndmin=ndmin)
+                    if trans1d and tempobj.ndim == 1:
+                        newobj = newobj.T
+                    del tempobj
             objs.append(newobj)
             if isinstance(newobj, _nx.ndarray) and not scalar:
                 if final_dtypedescr is None:
@@ -286,13 +313,13 @@
 class c_class(concatenator):
     """Translates slice objects to concatenation along the second axis.
 
-       This is deprecated.  Use r_[...,'-1']
+       This is deprecated.  Use r_['-1',...]
     """
     def __init__(self):
         concatenator.__init__(self, -1)
 
     def __getitem__(self, obj):
-        warnings.warn("c_ is deprecated use r_[...,'-1']")
+        warnings.warn("c_ is deprecated use r_['-1',...]")
         return concatenator.__getitem__(self, obj)
 
 c_ = c_class()

Modified: trunk/numpy/lib/tests/test_index_tricks.py
===================================================================
--- trunk/numpy/lib/tests/test_index_tricks.py	2006-09-14 01:49:10 UTC (rev 3149)
+++ trunk/numpy/lib/tests/test_index_tricks.py	2006-09-14 02:33:55 UTC (rev 3150)
@@ -39,7 +39,7 @@
     def check_2d(self):
         b = rand(5,5)
         c = rand(5,5)
-        d = r_[b,c,'1']  # append columns
+        d = r_['1',b,c]  # append columns
         assert(d.shape == (5,10))
         assert_array_equal(d[:,:5],b)
         assert_array_equal(d[:,5:],c)




More information about the Numpy-svn mailing list