[Scipy-svn] r2708 - in trunk/Lib/io: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Wed Feb 14 11:41:57 EST 2007
Author: matthew.brett at gmail.com
Date: 2007-02-14 10:41:53 -0600 (Wed, 14 Feb 2007)
New Revision: 2708
Modified:
trunk/Lib/io/npfile.py
trunk/Lib/io/tests/test_npfile.py
Log:
Added support for -1 shape values in read_array
Modified: trunk/Lib/io/npfile.py
===================================================================
--- trunk/Lib/io/npfile.py 2007-02-12 20:50:16 UTC (rev 2707)
+++ trunk/Lib/io/npfile.py 2007-02-14 16:41:53 UTC (rev 2708)
@@ -132,6 +132,13 @@
"""Write string to file as raw bytes."""
return self.file.write(str)
+ def remaining_bytes(self):
+ cur_pos = self.tell()
+ self.seek(0, 2)
+ end_pos = self.tell()
+ self.seek(cur_pos)
+ return end_pos - cur_pos
+
def _endian_order(self, endian, order):
''' Housekeeping function to return endian, order from input args '''
if endian is None:
@@ -167,13 +174,16 @@
data = data.byteswap()
self.file.write(data.tostring(order=order))
- def read_array(self, shape, dt, endian=None, order=None):
+ def read_array(self, dt, shape=-1, endian=None, order=None):
'''Read data from file and return it in a numpy array.
Inputs
------
+ dt - dtype of array to be read
shape - shape of output array, or number of elements
- dt - dtype of array to be read
+ (-1 as number of elements or element in shape
+ means unknown dimension as in reshape; size
+ of array calculated from remaining bytes in file)
endian - endianness of data in file
(can be None, 'dtype', '<', '>')
(if None, get from self.endian)
@@ -184,13 +194,26 @@
arr - array from file with given dtype (dt)
'''
endian, order = self._endian_order(endian, order)
+ dt = N.dtype(dt)
try:
- shape = tuple(shape)
+ shape = list(shape)
except TypeError:
- shape = (shape,)
- dt = N.dtype(dt)
+ shape = [shape]
+ minus_ones = shape.count(-1)
+ if minus_ones == 0:
+ pass
+ elif minus_ones == 1:
+ known_dimensions_size = -N.product(shape,axis=0) * dt.itemsize
+ unknown_dimension_size, illegal = divmod(self.remaining_bytes(),
+ known_dimensions_size)
+ if illegal:
+ raise ValueError("unknown dimension doesn't match filesize")
+ shape[shape.index(-1)] = unknown_dimension_size
+ else:
+ raise ValueError(
+ "illegal -1 count; can only specify one unknown dimension")
+ sz = dt.itemsize * N.product(shape)
dt_endian = self._endian_from_dtype(dt)
- sz = dt.itemsize * N.product(shape)
buf = self.file.read(sz)
arr = N.ndarray(shape=shape,
dtype=dt,
Modified: trunk/Lib/io/tests/test_npfile.py
===================================================================
--- trunk/Lib/io/tests/test_npfile.py 2007-02-12 20:50:16 UTC (rev 2707)
+++ trunk/Lib/io/tests/test_npfile.py 2007-02-14 16:41:53 UTC (rev 2708)
@@ -20,8 +20,7 @@
npf.write_array(arr)
npf.rewind()
self.assertRaises(IOError, npf.read_array,
- arr.shape,
- arr.dtype)
+ arr.dtype, arr.shape)
npf.close()
os.remove(fname)
@@ -48,6 +47,16 @@
npf.rewind()
assert str == npf.read_raw(len(str))
+ def test_remaining_bytes(self):
+ npf = npfile(StringIO())
+ assert npf.remaining_bytes() == 0
+ npf.write_raw('+' * 10)
+ assert npf.remaining_bytes() == 0
+ npf.rewind()
+ assert npf.remaining_bytes() == 10
+ npf.seek(5)
+ assert npf.remaining_bytes() == 5
+
def test_read_write_array(self):
npf = npfile(StringIO())
arr = N.reshape(N.arange(10), (5,2))
@@ -67,25 +76,29 @@
shp = arr.shape
npf.write_array(arr)
npf.rewind()
- assert_array_equal(npf.read_array(shp, adt), arr)
+ assert_array_equal(npf.read_array(adt), arr.flatten())
npf.rewind()
- assert_array_equal(npf.read_array(shp, adt, endian=swapped_code),
+ assert_array_equal(npf.read_array(adt, shp), arr)
+ npf.rewind()
+ assert_array_equal(npf.read_array(adt, shp, endian=swapped_code),
bs_arr)
npf.rewind()
- assert_array_equal(npf.read_array(shp, adt, order='F'),
+ assert_array_equal(npf.read_array(adt, shp, order='F'),
f_arr)
npf.rewind()
npf.write_array(arr, order='F')
npf.rewind()
- assert_array_equal(npf.read_array(shp, adt),
+ assert_array_equal(npf.read_array(adt), arr.flatten('F'))
+ npf.rewind()
+ assert_array_equal(npf.read_array(adt, shp),
cf_arr)
npf = npfile(StringIO(), endian='swapped', order='F')
npf.write_array(arr)
npf.rewind()
- assert_array_equal(npf.read_array(shp, adt), arr)
+ assert_array_equal(npf.read_array(adt, shp), arr)
npf.rewind()
- assert_array_equal(npf.read_array(shp, adt, endian='dtype'), bs_arr)
+ assert_array_equal(npf.read_array(adt, shp, endian='dtype'), bs_arr)
npf.rewind()
- assert_array_equal(npf.read_array(shp, adt, order='C'), cf_arr)
+ assert_array_equal(npf.read_array(adt, shp, order='C'), cf_arr)
More information about the Scipy-svn mailing list