[Scipy-svn] r5246 - in trunk/scipy/io/matlab: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Dec 13 21:30:24 EST 2008


Author: matthew.brett at gmail.com
Date: 2008-12-13 20:30:09 -0600 (Sat, 13 Dec 2008)
New Revision: 5246

Modified:
   trunk/scipy/io/matlab/byteordercodes.py
   trunk/scipy/io/matlab/mio5.py
   trunk/scipy/io/matlab/tests/test_mio.py
Log:
Added tests, and reimplemented fixes, from Lee Kamnetsky; refactored option passing somewhat; removed remaining bare assert statements in tests

Modified: trunk/scipy/io/matlab/byteordercodes.py
===================================================================
--- trunk/scipy/io/matlab/byteordercodes.py	2008-12-13 21:19:02 UTC (rev 5245)
+++ trunk/scipy/io/matlab/byteordercodes.py	2008-12-14 02:30:09 UTC (rev 5246)
@@ -36,7 +36,6 @@
     Examples
     --------
     >>> import sys
-    >>> from imagers.byteorder import to_numpy_code, sys_is_le
     >>> sys_is_le == (sys.byteorder == 'little')
     True
     >>> to_numpy_code('big')

Modified: trunk/scipy/io/matlab/mio5.py
===================================================================
--- trunk/scipy/io/matlab/mio5.py	2008-12-13 21:19:02 UTC (rev 5245)
+++ trunk/scipy/io/matlab/mio5.py	2008-12-14 02:30:09 UTC (rev 5246)
@@ -622,10 +622,10 @@
 
 
 class Mat5MatrixWriter(MatStreamWriter):
-
+    ''' Generic matlab matrix writing class '''
     mat_tag = np.zeros((), mdtypes_template['tag_full'])
     mat_tag['mdtype'] = miMATRIX
-
+    default_mclass = None # default class for header writing
     def __init__(self,
                  file_stream,
                  arr,
@@ -670,7 +670,7 @@
         # pad to next 64-bit boundary
         self.write_bytes(np.zeros((padding,),'u1'))
 
-    def write_header(self, mclass,
+    def write_header(self, mclass=None,
                      is_global=False,
                      is_complex=False,
                      is_logical=False,
@@ -686,6 +686,8 @@
             directly specify shape if this is not the same as for
             self.arr
         '''
+        if mclass is None:
+            mclass = self.default_mclass
         if shape is None:
             shape = self.arr.shape
             if len(shape) < 2:
@@ -712,10 +714,18 @@
         self.file_stream.seek(curr_pos)
 
     def write(self):
-        assert False, 'Not implemented'
+        raise NotImplementedError
 
+    def make_writer_getter(self):
+        ''' Make writer getter for this stream '''
+        return Mat5WriterGetter(self.file_stream,
+                                self.unicode_strings,
+                                self.long_field_names)
 
+
+
 class Mat5NumericWriter(Mat5MatrixWriter):
+    default_mclass = None # can be any numeric type
     def write(self):
         imagf = self.arr.dtype.kind == 'c'
         try:
@@ -737,13 +747,14 @@
 
 class Mat5CharWriter(Mat5MatrixWriter):
     codec='ascii'
+    default_mclass = mxCHAR_CLASS
     def write(self):
         self.arr_to_chars()
         # We have to write the shape directly, because we are going
         # recode the characters, and the resulting stream of chars
         # may have a different length
         shape = self.arr.shape
-        self.write_header(mclass=mxCHAR_CLASS,shape=shape)
+        self.write_header(shape=shape)
         # We need to do our own transpose (not using the normal
         # write routines that do this for us)
         arr = self.arr.T.copy()
@@ -766,6 +777,7 @@
 
 
 class Mat5SparseWriter(Mat5MatrixWriter):
+    default_mclass = mxSPARSE_CLASS
     def write(self):
         ''' Sparse matrices are 2D
         '''
@@ -773,8 +785,7 @@
         A.sort_indices()     # MATLAB expects sorted row indices
         is_complex = (A.dtype.kind == 'c')
         nz = A.nnz
-        self.write_header(mclass=mxSPARSE_CLASS,
-                          is_complex=is_complex,
+        self.write_header(is_complex=is_complex,
                           nzmax=nz)
         self.write_element(A.indices.astype('i4'))
         self.write_element(A.indptr.astype('i4'))
@@ -785,49 +796,51 @@
 
 
 class Mat5CellWriter(Mat5MatrixWriter):
+    default_mclass = mxCELL_CLASS
     def write(self):
-        self.write_header(mclass=mxCELL_CLASS)
+        self.write_header()
+        self._write_items()
+        
+    def _write_items(self):
         # loop over data, column major
         A = np.atleast_2d(self.arr).flatten('F')
-        MWG = Mat5WriterGetter(self.file_stream, self.unicode_strings)
+        MWG = self.make_writer_getter()
         for el in A:
             MW = MWG.matrix_writer_factory(el, '')
             MW.write()
         self.update_matrix_tag()
 
 
-class Mat5FunctionWriter(Mat5MatrixWriter):
-    def write(self):
-        self.write_header(mclass=mxFUNCTION_CLASS)
-        # loop over data, column major
-        A = np.atleast_2d(self.arr).flatten('F')
-        MWG = Mat5WriterGetter(self.file_stream, self.unicode_strings)
-        for el in A:
-            MW = MWG.matrix_writer_factory(el, '')
-            MW.write()
-        self.update_matrix_tag()
+class Mat5FunctionWriter(Mat5CellWriter):
+    ''' class to write matlab functions
 
+    Only differs from cell writing in mx class in header '''
+    default_mclass = mxFUNCTION_CLASS
 
-class Mat5StructWriter(Mat5MatrixWriter):
-    def write(self):
-        self.write_header(mclass=mxSTRUCT_CLASS)
-        self.write_fields()
 
-    def write_fields(self):
+class Mat5StructWriter(Mat5CellWriter):
+    ''' class to write matlab structs
+
+    Differs from cell writing class in writing field names,
+    and in mx class
+    '''
+    default_mclass = mxSTRUCT_CLASS
+
+    def _write_items(self):
         # write fieldnames
         fieldnames = [f[0] for f in self.arr.dtype.descr]
         length = max([len(fieldname) for fieldname in fieldnames])+1
         max_length = (self.long_field_names and 64) or 32
         if length > max_length:
             raise ValueError(
-                "Field names are restricted to %d characters in Matlab"%(max_length-1))
+                "Field names are restricted to %d characters"
+                 % (max_length-1))
         self.write_element(np.array([length], dtype='i4'))
         self.write_element(
             np.array(fieldnames, dtype='S%d'%(length)),
             mdtype=miINT8)
         A = np.atleast_2d(self.arr).flatten('F')
-        MWG = Mat5WriterGetter(self.file_stream,
-                               self.unicode_strings)
+        MWG = self.make_writer_getter()
         for el in A:
             for f in fieldnames:
                 MW = MWG.matrix_writer_factory(el[f], '')
@@ -836,11 +849,17 @@
 
 
 class Mat5ObjectWriter(Mat5StructWriter):
+    ''' class to write matlab objects
+
+    Same as writing structs, except different mx class, and extra
+    classname element after header
+    '''
+    default_mclass = mxOBJECT_CLASS
     def write(self):
-        self.write_header(mclass=mxOBJECT_CLASS)
+        self.write_header()
         self.write_element(np.array(self.arr.classname, dtype='S'),
                            mdtype=miINT8)
-        self.write_fields()
+        self._write_items()
 
 
 class Mat5WriterGetter(object):
@@ -871,7 +890,7 @@
         # Next try and convert to an array
         narr = np.asanyarray(arr)
         if narr.dtype.type in (np.object, np.object_) and \
-           narr.size == 1 and narr == arr:
+           narr.shape == () and narr == arr:
             # No interesting conversion possible
             raise TypeError('Could not convert %s (type %s) to array'
                             % (arr, type(arr)))
@@ -926,7 +945,7 @@
         file_stream.write(hdr.tostring())
 
     def get_unicode_strings(self):
-        return self.write_getter.unicode_strings
+        return self.writer_getter.unicode_strings
     def set_unicode_strings(self, unicode_strings):
         self.writer_getter.unicode_strings = unicode_strings
     unicode_strings = property(get_unicode_strings,
@@ -935,7 +954,7 @@
                                'get/set unicode strings property')
 
     def get_long_field_names(self):
-        return self.write_getter.long_field_names
+        return self.writer_getter.long_field_names
     def set_long_field_names(self, long_field_names):
         self.writer_getter.long_field_names = long_field_names
     long_field_names = property(get_long_field_names,

Modified: trunk/scipy/io/matlab/tests/test_mio.py
===================================================================
--- trunk/scipy/io/matlab/tests/test_mio.py	2008-12-13 21:19:02 UTC (rev 5245)
+++ trunk/scipy/io/matlab/tests/test_mio.py	2008-12-14 02:30:09 UTC (rev 5246)
@@ -25,7 +25,7 @@
 import scipy.sparse as SP
 
 from scipy.io.matlab.mio import loadmat, savemat, find_mat_file
-from scipy.io.matlab.mio5 import MatlabObject
+from scipy.io.matlab.mio5 import MatlabObject, MatFile5Writer
 
 test_data_path = join(dirname(__file__), 'data')
 
@@ -196,7 +196,7 @@
 def _check_level(label, expected, actual):
     """ Check one level of a potentially nested array """
     if SP.issparse(expected): # allow different types of sparse matrices
-        assert SP.issparse(actual)
+        assert_true(SP.issparse(actual))
         assert_array_almost_equal(actual.todense(),
                                   expected.todense(),
                                   err_msg = label,
@@ -205,8 +205,8 @@
     # Check types are as expected
     typex = type(expected)
     typac = type(actual)
-    assert typex is typac, \
-           "Expected type %s, got %s at %s" % (typex, typac, label)
+    assert_true(typex is typac, \
+           "Expected type %s, got %s at %s" % (typex, typac, label))
     # A field in a record array may not be an ndarray
     # A scalar from a record array will be type np.void
     if not isinstance(expected,
@@ -246,7 +246,7 @@
         label = "test %s; file %s" % (name, file_name)
         for k, expected in case.items():
             k_label = "%s, variable %s" % (label, k)
-            assert k in matdict, "Missing key at %s" % k_label
+            assert_true(k in matdict, "Missing key at %s" % k_label)
             _check_level(k_label, expected, matdict[k])
 
 # Round trip tests
@@ -264,7 +264,8 @@
         expected = case['expected']
         filt = join(test_data_path, 'test%s_*.mat' % name)
         files = glob(filt)
-        assert files, "No files for test %s using filter %s" % (name, filt)
+        assert_true(len(files) > 0,
+                    "No files for test %s using filter %s" % (name, filt))
         yield _load_check_case, name, files, expected
 
 
@@ -308,7 +309,7 @@
     # Check any hdf5 files raise an error
     filenames = glob(
         join(test_data_path, 'testhdf5*.mat'))
-    assert len(filenames)
+    assert_true(len(filenames)>0)
     for filename in filenames:
         assert_raises(NotImplementedError,
                       loadmat,
@@ -370,3 +371,49 @@
     st1 = np.zeros((1,1), dtype=[(fldname, object)])
     assert_raises(ValueError, savemat, StringIO(),
                   {'longstruct': st1}, format='5',long_field_names=True)
+
+
+def test_long_field_names_in_struct():
+    # Regression test - long_field_names was erased if you passed a struct
+    # within a struct
+    lim = 63
+    fldname = 'a' * lim
+    cell = np.ndarray((1,2),dtype=object)
+    st1 = np.zeros((1,1), dtype=[(fldname, object)])
+    cell[0,0]=st1
+    cell[0,1]=st1
+    mat_stream = StringIO()
+    savemat(StringIO(), {'longstruct': cell}, format='5',long_field_names=True)
+    #
+    # Check to make sure it fails with long field names off
+    #
+    assert_raises(ValueError, savemat, StringIO(),
+                  {'longstruct': cell}, format='5', long_field_names=False)
+
+def test_cell_with_one_thing_in_it():
+    # Regression test - make a cell array that's 1 x 2 and put two
+    # strings in it.  It works. Make a cell array that's 1 x 1 and put
+    # a string in it. It should work but, in the old days, it didn't.
+    cells = np.ndarray((1,2),dtype=object)
+    cells[0,0]='Hello'
+    cells[0,1]='World'
+    mat_stream = StringIO()
+    savemat(StringIO(), {'x': cells}, format='5')
+
+    cells = np.ndarray((1,1),dtype=object)
+    cells[0,0]='Hello, world'
+    mat_stream = StringIO()
+    savemat(StringIO(), {'x': cells}, format='5')
+
+def test_writer_properties():
+    # Tests getting, setting of properties of matrix writer
+    mfw = MatFile5Writer(StringIO())
+    yield assert_equal, mfw.global_vars, []
+    mfw.global_vars = ['avar']
+    yield assert_equal, mfw.global_vars, ['avar']
+    yield assert_equal, mfw.unicode_strings, False
+    mfw.unicode_strings = True
+    yield assert_equal, mfw.unicode_strings, True
+    yield assert_equal, mfw.long_field_names, False
+    mfw.long_field_names = True
+    yield assert_equal, mfw.long_field_names, True




More information about the Scipy-svn mailing list