[Numpy-svn] r3151 - trunk/numpy/lib

numpy-svn at scipy.org numpy-svn at scipy.org
Wed Sep 13 23:19:13 EDT 2006


Author: oliphant
Date: 2006-09-13 22:19:04 -0500 (Wed, 13 Sep 2006)
New Revision: 3151

Modified:
   trunk/numpy/lib/index_tricks.py
Log:
Fix the transpose implementation so that it also works with higher-dimensional arrays.

Modified: trunk/numpy/lib/index_tricks.py
===================================================================
--- trunk/numpy/lib/index_tricks.py	2006-09-14 02:33:55 UTC (rev 3150)
+++ trunk/numpy/lib/index_tricks.py	2006-09-14 03:19:04 UTC (rev 3151)
@@ -207,7 +207,7 @@
         self.col = 0
 
     def __getitem__(self,key):
-        trans1d = False
+        trans1d = -1
         ndmin = 1
         if isinstance(key, str):
             frame = sys._getframe().f_back
@@ -234,8 +234,8 @@
                     newobj = _nx.arange(start, stop, step)
                 if ndmin > 1:
                     newobj = array(newobj,copy=False,ndmin=ndmin)
-                    if trans1d:
-                        newobj = newobj.T
+                    if trans1d != -1:
+                        newobj = newobj.swapaxes(-1,trans1d)
             elif isinstance(key[k],str):
                 if k != 0:
                     raise ValueError, "special directives must be the"\
@@ -250,8 +250,8 @@
                     try:
                         self.axis, ndmin = \
                                    [int(x) for x in vec[:2]]
-                        if len(vec) == 3 and vec[2] == 't':
-                            trans1d = True
+                        if len(vec) == 3:
+                            trans1d = int(vec[2])
                         continue
                     except:
                         raise ValueError, "unknown special directive"
@@ -270,8 +270,15 @@
                     tempobj = array(newobj, copy=False, subok=True)
                     newobj = array(newobj, copy=False, subok=True,
                                    ndmin=ndmin)
-                    if trans1d and tempobj.ndim == 1:
-                        newobj = newobj.T
+                    if trans1d != -1 and tempobj.ndim < ndmin:
+                        k2 = ndmin-tempobj.ndim                        
+                        if (trans1d < 0):
+                            trans1d += k2 + 1
+                        defaxes = range(ndmin)
+                        k1 = trans1d
+                        axes = defaxes[:k1] + defaxes[k2:] + \
+                               defaxes[k1:k2]
+                        newobj = newobj.transpose(axes)
                     del tempobj
             objs.append(newobj)
             if isinstance(newobj, _nx.ndarray) and not scalar:




More information about the Numpy-svn mailing list