[Scipy-svn] r5029 - in trunk/scipy/io/matlab: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun Nov 9 00:32:51 EST 2008
Author: matthew.brett at gmail.com
Date: 2008-11-08 23:32:46 -0600 (Sat, 08 Nov 2008)
New Revision: 5029
Modified:
trunk/scipy/io/matlab/mio5.py
trunk/scipy/io/matlab/tests/gen_mat5files.m
trunk/scipy/io/matlab/tests/test_mio.py
Log:
Added all load checks to roundtrip checks, fixed transpose in save of char data in matlab 5, all 51 tests now pass for me
Modified: trunk/scipy/io/matlab/mio5.py
===================================================================
--- trunk/scipy/io/matlab/mio5.py 2008-11-09 04:13:26 UTC (rev 5028)
+++ trunk/scipy/io/matlab/mio5.py 2008-11-09 05:32:46 UTC (rev 5029)
@@ -728,14 +728,22 @@
is_global=False,
is_complex=False,
is_logical=False,
- nzmax=0):
+ nzmax=0,
+ shape=None):
''' Write header for given data options
mclass - mat5 matrix class
is_global - True if matrix is global
is_complex - True if matrix is complex
is_logical - True if matrix is logical
nzmax - max non zero elements for sparse arrays
+ shape : {None, tuple} optional
+ directly specify shape if this is not the same as for
+ self.arr; shapes with fewer than 2 dims are left-padded with 1s
'''
+ if shape is None:
+ shape = self.arr.shape
+ if len(shape) < 2:
+ shape = (1,) * (2 - len(shape)) + tuple(shape)  # MAT dims need >= 2 entries; pad like np.atleast_2d (old code padded by a negative count, i.e. never)
self._mat_tag_pos = self.file_stream.tell()
self.write_dtype(self.mat_tag)
# write array flags (complex, global, logical, class, nzmax)
@@ -746,13 +754,7 @@
af['flags_class'] = mclass | flags << 8
af['nzmax'] = nzmax
self.write_dtype(af)
- # write array shape
- if self.arr.ndim < 2:
- new_arr = np.atleast_2d(self.arr)
- if type(new_arr) != type(self.arr):
- raise ValueError("Array should be 2-dimensional.")
- self.arr = new_arr
- self.write_element(np.array(self.arr.shape, dtype='i4'))
+ self.write_element(np.array(shape, dtype='i4'))  # write array shape (dimensions)
# write name
self.write_element(np.array([ord(c) for c in self.name], 'i1'))
@@ -786,22 +788,33 @@
self.write_element(self.arr)
self.update_matrix_tag()
+
class Mat5CharWriter(Mat5MatrixWriter):
codec='ascii'
def write(self):
self.arr_to_chars()
- self.write_header(mclass=mxCHAR_CLASS)
+ # We have to write the shape directly, because we are going to
+ # recode the characters, and the resulting stream of chars
+ # may have a different length
+ shape = self.arr.shape
+ self.write_header(mclass=mxCHAR_CLASS,shape=shape)
+ # Do our own transpose: after recoding, arr is a flat byte stream,
+ # so the normal write routines cannot put it in matlab (column-major) order
+ arr = self.arr.T.copy()
if self.arr.dtype.kind == 'U':
# Recode unicode using self.codec
- n_chars = np.product(self.arr.shape)
+ n_chars = np.product(shape)
st_arr = np.ndarray(shape=(),
dtype=self.arr_dtype_number(n_chars),
- buffer=self.arr)
+ buffer=arr)
st = st_arr.item().encode(self.codec)
- self.arr = np.ndarray(shape=(len(st)), dtype='u1', buffer=st)
- self.write_element(self.arr,mdtype=miUTF8)
+ arr = np.ndarray(shape=(len(st),),
+ dtype='u1',
+ buffer=st)
+ self.write_element(arr, mdtype=miUTF8)
self.update_matrix_tag()
+
class Mat5UniCharWriter(Mat5CharWriter):
codec='UTF8'
@@ -976,17 +989,20 @@
continue
is_global = name in self.global_vars
self.writer_getter.rewind()
- self.writer_getter.matrix_writer_factory(
+ mat_writer = self.writer_getter.matrix_writer_factory(
var,
name,
- is_global,
- ).write()
+ is_global)
+ mat_writer.write()
stream = self.writer_getter.stream
+ bytes_written = stream.tell()
+ stream.seek(0)
+ out_str = stream.read(bytes_written)
if self.do_compression:
- str = zlib.compress(stream.getvalue(stream.tell()))
+ out_str = zlib.compress(out_str)
tag = np.empty((), mdtypes_template['tag_full'])
tag['mdtype'] = miCOMPRESSED
- tag['byte_count'] = len(str)
+ tag['byte_count'] = len(out_str)  # size of the compressed payload written after the tag (`str` is no longer assigned; it would be the builtin)
- self.file_stream.write(tag.tostring() + str)
+ self.file_stream.write(tag.tostring() + out_str)
else:
- self.file_stream.write(stream.getvalue(stream.tell()))
+ self.file_stream.write(out_str)
Modified: trunk/scipy/io/matlab/tests/gen_mat5files.m
===================================================================
--- trunk/scipy/io/matlab/tests/gen_mat5files.m 2008-11-09 04:13:26 UTC (rev 5028)
+++ trunk/scipy/io/matlab/tests/gen_mat5files.m 2008-11-09 05:32:46 UTC (rev 5029)
@@ -89,4 +89,8 @@
fclose(fid);
save_matfile('testunicode', native2unicode(from_japan, 'utf-8'));
end
-
\ No newline at end of file
+
+% TODO: sparse float (testsparsefloat) - not yet generated; round-trip tested only
+
+
+% TODO: sparse complex (testsparsecomplex) - not yet generated; round-trip tested only
Modified: trunk/scipy/io/matlab/tests/test_mio.py
===================================================================
--- trunk/scipy/io/matlab/tests/test_mio.py 2008-11-09 04:13:26 UTC (rev 5028)
+++ trunk/scipy/io/matlab/tests/test_mio.py 2008-11-09 05:32:46 UTC (rev 5029)
@@ -25,87 +25,16 @@
test_data_path = join(dirname(__file__), 'data')
def mlarr(*args, **kwargs):
- ''' Return matlab-compatible 2D array'''
+ ''' Convenience function to return matlab-compatible 2D array
+ Non-empty input goes through np.atleast_2d; empty input is reshaped
+ to (0,0), matching how matlab writes empty arrays '''
arr = np.array(*args, **kwargs)
if arr.size:
return np.atleast_2d(arr)
# empty elements return as shape (0,0)
return arr.reshape((0,0))
-def _check_level(label, expected, actual):
- """ Check one level of a potentially nested array """
- if SP.issparse(expected): # allow different types of sparse matrices
- assert SP.issparse(actual)
- assert_array_almost_equal(actual.todense(),
- expected.todense(),
- err_msg = label,
- decimal = 5)
- return
- # Check types are as expected
- typex = type(expected)
- typac = type(actual)
- assert typex is typac, \
- "Expected type %s, got %s at %s" % (typex, typac, label)
- # object, as container for matlab objects
- if isinstance(expected, MatlabObject):
- ex_fields = dir(expected)
- ac_fields = dir(actual)
- for k in ex_fields:
- if k.startswith('__') and k.endswith('__'):
- continue
- assert k in ac_fields, \
- "Missing expected property %s for %s" % (k, label)
- ev = expected.__dict__[k]
- v = actual.__dict__[k]
- level_label = "%s, property %s, " % (label, k)
- _check_level(level_label, ev, v)
- return
- # A field in a record array may not be an ndarray
- # A scalar from a record array will be type np.void
- if not isinstance(expected, (np.void, np.ndarray)):
- assert_equal(expected, actual)
- return
- # This is an ndarray
- assert_true(expected.shape == actual.shape,
- msg='Expected shape %s, got %s at %s' % (expected.shape,
- actual.shape,
- label)
- )
- ex_dtype = expected.dtype
- if ex_dtype.hasobject: # array of objects
- for i, ev in enumerate(expected):
- level_label = "%s, [%d], " % (label, i)
- _check_level(level_label, ev, actual[i])
- return
- if ex_dtype.fields: # probably recarray
- for fn in ex_dtype.fields:
- level_label = "%s, field %s, " % (label, fn)
- _check_level(level_label,
- expected[fn], actual[fn])
- return
- if ex_dtype.type in (np.unicode, # string
- np.unicode_):
- assert_equal(actual, expected, err_msg=label)
- return
- # Something numeric
- assert_array_almost_equal(actual, expected, err_msg=label, decimal=5)
-def _check_case(name, files, case):
- for file_name in files:
- matdict = loadmat(file_name, struct_as_record=True)
- label = "test %s; file %s" % (name, file_name)
- for k, expected in case.items():
- k_label = "%s, variable %s" % (label, k)
- assert k in matdict, "Missing key at %s" % k_label
- _check_level(k_label, expected, matdict[k])
-
-# Round trip tests
-def _rt_check_case(name, expected, format):
- mat_stream = StringIO()
- savemat(mat_stream, expected, format=format)
- mat_stream.seek(0)
- _check_case(name, [mat_stream], expected)
-
# Define cases to test
theta = np.pi/4*np.arange(9,dtype=float).reshape(1,9)
case_table4 = [
@@ -181,20 +110,6 @@
'expected': {
'test3dmatrix': np.transpose(np.reshape(range(1,25), (4,3,2)))}
})
-case_table5_rt = [
- {'name': '3dmatrix',
- 'expected': {
- 'test3dmatrix': np.transpose(np.reshape(range(1,25), (4,3,2)))}
- },
- {'name': 'sparsefloat',
- 'expected': {'testsparsefloat':
- SP.coo_matrix(array([[1,0,2],[0,-3.5,0]]))},
- },
- {'name': 'sparsecomplex',
- 'expected': {'testsparsefloat':
- SP.coo_matrix(array([[-1+2j,0,2],[0,-3j,0]]))},
- },
- ]
st_sub_arr = array([np.sqrt(2),np.exp(1),np.pi]).reshape(1,3)
dtype = [(n, object) for n in ['stringfield', 'doublefield', 'complexfield']]
st1 = np.zeros((1,1), dtype)
@@ -254,7 +169,95 @@
{'name': 'unicode',
'expected': {'testunicode': array([u_str])}
})
+# TODO: these sparse cases also need matlab-generated load-test files; until then they are round-trip only
+case_table5_rt = case_table5[:]
+case_table5_rt.append(
+ {'name': 'sparsefloat',
+ 'expected': {'testsparsefloat':
+ SP.coo_matrix(array([[1,0,2],[0,-3.5,0]]))},
+ })
+case_table5_rt.append(
+ {'name': 'sparsecomplex',
+ 'expected': {'testsparsecomplex':
+ SP.coo_matrix(array([[-1+2j,0,2],[0,-3j,0]]))},
+ })
+
+def _check_level(label, expected, actual):
+ """ Check one level of a potentially nested array """
+ if SP.issparse(expected): # allow different types of sparse matrices
+ assert SP.issparse(actual)
+ assert_array_almost_equal(actual.todense(),
+ expected.todense(),
+ err_msg = label,
+ decimal = 5)
+ return
+ # Check types are as expected
+ typex = type(expected)
+ typac = type(actual)
+ assert typex is typac, \
+ "Expected type %s, got %s at %s" % (typex, typac, label)
+ # object, as container for matlab objects
+ if isinstance(expected, MatlabObject):
+ ex_fields = dir(expected)
+ ac_fields = dir(actual)
+ for k in ex_fields:
+ if k.startswith('__') and k.endswith('__'):
+ continue
+ assert k in ac_fields, \
+ "Missing expected property %s for %s" % (k, label)
+ ev = expected.__dict__[k]
+ v = actual.__dict__[k]
+ level_label = "%s, property %s, " % (label, k)
+ _check_level(level_label, ev, v)
+ return
+ # A field in a record array may not be an ndarray
+ # A scalar from a record array will be type np.void
+ if not isinstance(expected, (np.void, np.ndarray)):
+ assert_equal(expected, actual)
+ return
+ # This is an ndarray
+ assert_true(expected.shape == actual.shape,
+ msg='Expected shape %s, got %s at %s' % (expected.shape,
+ actual.shape,
+ label)
+ )
+ ex_dtype = expected.dtype
+ if ex_dtype.hasobject: # array of objects
+ for i, ev in enumerate(expected):
+ level_label = "%s, [%d], " % (label, i)
+ _check_level(level_label, ev, actual[i])
+ return
+ if ex_dtype.fields: # probably recarray
+ for fn in ex_dtype.fields:
+ level_label = "%s, field %s, " % (label, fn)
+ _check_level(level_label,
+ expected[fn], actual[fn])
+ return
+ if ex_dtype.type in (np.unicode, # string
+ np.unicode_):
+ assert_equal(actual, expected, err_msg=label)
+ return
+ # Something numeric
+ assert_array_almost_equal(actual, expected, err_msg=label, decimal=5)
+
+def _load_check_case(name, files, case):
+ for file_name in files:
+ matdict = loadmat(file_name, struct_as_record=True)
+ label = "test %s; file %s" % (name, file_name)
+ for k, expected in case.items():
+ k_label = "%s, variable %s" % (label, k)
+ assert k in matdict, "Missing key at %s" % k_label
+ _check_level(k_label, expected, matdict[k])
+
+# Round trip tests: savemat into an in-memory stream, rewind, then load and compare
+def _rt_check_case(name, expected, format):
+ mat_stream = StringIO()
+ savemat(mat_stream, expected, format=format)
+ mat_stream.seek(0)
+ _load_check_case(name, [mat_stream], expected)
+
+
# generator for load tests
def test_load():
for case in case_table4 + case_table5:
@@ -263,7 +266,7 @@
filt = join(test_data_path, 'test%s_*.mat' % name)
files = glob(filt)
assert files, "No files for test %s using filter %s" % (name, filt)
- yield _check_case, name, files, expected
+ yield _load_check_case, name, files, expected
# generator for round trip tests
def test_round_trip():
More information about the Scipy-svn
mailing list