[Scipy-svn] r2708 - in trunk/Lib/io: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Wed Feb 14 11:41:57 EST 2007


Author: matthew.brett at gmail.com
Date: 2007-02-14 10:41:53 -0600 (Wed, 14 Feb 2007)
New Revision: 2708

Modified:
   trunk/Lib/io/npfile.py
   trunk/Lib/io/tests/test_npfile.py
Log:
Added support for -1 shape values in read_array

Modified: trunk/Lib/io/npfile.py
===================================================================
--- trunk/Lib/io/npfile.py	2007-02-12 20:50:16 UTC (rev 2707)
+++ trunk/Lib/io/npfile.py	2007-02-14 16:41:53 UTC (rev 2708)
@@ -132,6 +132,13 @@
         """Write string to file as raw bytes."""
         return self.file.write(str)
 
+    def remaining_bytes(self):
+        cur_pos = self.tell()
+        self.seek(0, 2)
+        end_pos = self.tell()
+        self.seek(cur_pos)
+        return end_pos - cur_pos
+
     def _endian_order(self, endian, order):
         ''' Housekeeping function to return endian, order from input args '''
         if endian is None:
@@ -167,13 +174,16 @@
                 data = data.byteswap()
         self.file.write(data.tostring(order=order))
         
-    def read_array(self, shape, dt, endian=None, order=None):
+    def read_array(self, dt, shape=-1, endian=None, order=None):
         '''Read data from file and return it in a numpy array.
         
         Inputs
         ------
+        dt        - dtype of array to be read
         shape     - shape of output array, or number of elements
-        dt        - dtype of array to be read
+                    (-1 as number of elements or element in shape
+                    means unknown dimension as in reshape; size
+                    of array calculated from remaining bytes in file)
         endian    - endianness of data in file
                     (can be None, 'dtype', '<', '>')
                     (if None, get from self.endian)
@@ -184,13 +194,26 @@
         arr       - array from file with given dtype (dt)
         '''
         endian, order = self._endian_order(endian, order)
+        dt = N.dtype(dt)
         try:
-            shape = tuple(shape)
+            shape = list(shape)
         except TypeError:
-            shape = (shape,)
-        dt = N.dtype(dt)
+            shape = [shape]
+        minus_ones = shape.count(-1)
+        if minus_ones == 0:
+            pass
+        elif minus_ones == 1:
+            known_dimensions_size = -N.product(shape,axis=0) * dt.itemsize
+            unknown_dimension_size, illegal = divmod(self.remaining_bytes(),
+                                                     known_dimensions_size)
+            if illegal:
+                raise ValueError("unknown dimension doesn't match filesize")
+            shape[shape.index(-1)] = unknown_dimension_size
+        else:
+            raise ValueError(
+                "illegal -1 count; can only specify one unknown dimension")
+        sz = dt.itemsize * N.product(shape)
         dt_endian = self._endian_from_dtype(dt)
-        sz = dt.itemsize * N.product(shape)
         buf = self.file.read(sz)
         arr = N.ndarray(shape=shape,
                          dtype=dt,

Modified: trunk/Lib/io/tests/test_npfile.py
===================================================================
--- trunk/Lib/io/tests/test_npfile.py	2007-02-12 20:50:16 UTC (rev 2707)
+++ trunk/Lib/io/tests/test_npfile.py	2007-02-14 16:41:53 UTC (rev 2708)
@@ -20,8 +20,7 @@
         npf.write_array(arr)
         npf.rewind()
         self.assertRaises(IOError, npf.read_array,
-                          arr.shape,
-                          arr.dtype)
+                          arr.dtype, arr.shape)
         npf.close()
         os.remove(fname)
 
@@ -48,6 +47,16 @@
         npf.rewind()
         assert str == npf.read_raw(len(str))
         
+    def test_remaining_bytes(self):
+        npf = npfile(StringIO())
+        assert npf.remaining_bytes() == 0
+        npf.write_raw('+' * 10)
+        assert npf.remaining_bytes() == 0
+        npf.rewind()
+        assert npf.remaining_bytes() == 10
+        npf.seek(5)
+        assert npf.remaining_bytes() == 5
+
     def test_read_write_array(self):
         npf = npfile(StringIO())
         arr = N.reshape(N.arange(10), (5,2))
@@ -67,25 +76,29 @@
         shp = arr.shape
         npf.write_array(arr)
         npf.rewind()
-        assert_array_equal(npf.read_array(shp, adt), arr)
+        assert_array_equal(npf.read_array(adt), arr.flatten())
         npf.rewind()
-        assert_array_equal(npf.read_array(shp, adt, endian=swapped_code),
+        assert_array_equal(npf.read_array(adt, shp), arr)
+        npf.rewind()
+        assert_array_equal(npf.read_array(adt, shp, endian=swapped_code),
                            bs_arr)
         npf.rewind()
-        assert_array_equal(npf.read_array(shp, adt, order='F'),
+        assert_array_equal(npf.read_array(adt, shp, order='F'),
                            f_arr)
         npf.rewind()
         npf.write_array(arr, order='F')
         npf.rewind()
-        assert_array_equal(npf.read_array(shp, adt),
+        assert_array_equal(npf.read_array(adt), arr.flatten('F'))
+        npf.rewind()
+        assert_array_equal(npf.read_array(adt, shp),
                            cf_arr)
         
         npf = npfile(StringIO(), endian='swapped', order='F')
         npf.write_array(arr)
         npf.rewind()
-        assert_array_equal(npf.read_array(shp, adt), arr)
+        assert_array_equal(npf.read_array(adt, shp), arr)
         npf.rewind()
-        assert_array_equal(npf.read_array(shp, adt, endian='dtype'), bs_arr)
+        assert_array_equal(npf.read_array(adt, shp, endian='dtype'), bs_arr)
         npf.rewind()
-        assert_array_equal(npf.read_array(shp, adt, order='C'), cf_arr)
+        assert_array_equal(npf.read_array(adt, shp, order='C'), cf_arr)
         




More information about the Scipy-svn mailing list