[Scipy-svn] r3435 - in trunk/scipy/io: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Oct 13 14:33:37 EDT 2007


Author: matthew.brett at gmail.com
Date: 2007-10-13 13:33:22 -0500 (Sat, 13 Oct 2007)
New Revision: 3435

Modified:
   trunk/scipy/io/mio.py
   trunk/scipy/io/mio5.py
   trunk/scipy/io/tests/test_mio.py
Log:
Basic matlab 5 support - thanks to Lars Voxen Hansen

Modified: trunk/scipy/io/mio.py
===================================================================
--- trunk/scipy/io/mio.py	2007-10-12 17:16:54 UTC (rev 3434)
+++ trunk/scipy/io/mio.py	2007-10-13 18:33:22 UTC (rev 3435)
@@ -27,7 +27,7 @@
     else:
         full_name = None
         junk, file_name = os.path.split(file_name)
-        for path in sys.path:
+        for path in [os.curdir] + list(sys.path):
             test_name = os.path.join(path, file_name)
             if appendmat:
                 test_name += ".mat"
@@ -100,13 +100,14 @@
         mdict = matfile_dict
     return mdict
 
-def savemat(file_name, mdict, appendmat=True):
+def savemat(file_name, mdict, appendmat=True, format='4'):
     """Save a dictionary of names and arrays into the MATLAB-style .mat file.
 
     This saves the arrayobjects in the given dictionary to a matlab
-    Version 4 style .mat file.
+    style .mat file.
     
-    @appendmat  - if true, appends '.mat' extension to filename, if not present
+    appendmat  - if true, appends '.mat' extension to filename, if not present
+    format     - '4' for matlab 4 mat files, '5' for matlab 5 onwards
     """
     file_is_string = isinstance(file_name, basestring)
     if file_is_string:
@@ -121,7 +122,12 @@
                            'file-like object'
         file_stream = file_name
         
-    MW = MatFile4Writer(file_stream)
+    if format == '4':
+        MW = MatFile4Writer(file_stream)
+    elif format == '5':
+        MW = MatFile5Writer(file_stream)
+    else:
+        raise ValueError, 'Format should be 4 or 5'
     MW.put_variables(mdict)
     if file_is_string:
         file_stream.close()

Modified: trunk/scipy/io/mio5.py
===================================================================
--- trunk/scipy/io/mio5.py	2007-10-12 17:16:54 UTC (rev 3434)
+++ trunk/scipy/io/mio5.py	2007-10-13 18:33:22 UTC (rev 3435)
@@ -105,6 +105,41 @@
     mxDOUBLE_CLASS: 'f8',
     }
 
+
+np_to_mtypes = {
+    'f8': miDOUBLE,
+    'c32': miDOUBLE,    
+    'c24': miDOUBLE,
+    'c16': miDOUBLE,
+    'f4': miSINGLE,
+    'c8': miSINGLE,
+    'i1': miINT8,
+    'i2': miINT16,
+    'i4': miINT32,
+    'u1': miUINT8,
+    'u4': miUINT32,
+    'u2': miUINT16,
+    'S1': miUINT8,
+    'U1': miUTF16,
+    }
+
+
+np_to_mxtypes = {
+    'f8': mxDOUBLE_CLASS,
+    'c32': mxDOUBLE_CLASS,    
+    'c24': mxDOUBLE_CLASS,
+    'c16': mxDOUBLE_CLASS,
+    'f4': mxSINGLE_CLASS,
+    'c8': mxSINGLE_CLASS,
+    'i4': mxINT32_CLASS,
+    'i2': mxINT16_CLASS,
+    'u2': mxUINT16_CLASS,
+    'u1': mxUINT8_CLASS,
+    'S1': mxUINT8_CLASS,
+    }
+
+
+
 ''' Before release v7.1 (release 14) matlab (TM) used the system
 default character encoding scheme padded out to 16-bits. Release 14
 and later use Unicode. When saving character data, R14 checks if it
@@ -532,12 +567,22 @@
         self.is_global = is_global
 
     def write_dtype(self, arr):
-        self.file_stream.write(arr.tostring)
+        self.file_stream.write(arr.tostring())
 
-    def write_element(self, arr):
-        # check if small element works - do it
+    def write_element(self, arr, mdtype=None):
         # write tag, data
-        pass
+        tag = N.zeros((), mdtypes_template['tag_full'])
+        if mdtype is None:
+            tag['mdtype'] = np_to_mtypes[arr.dtype.str[1:]]
+        else:
+            tag['mdtype'] = mdtype
+        tag['byte_count'] =  arr.size*arr.itemsize
+        self.write_dtype(tag)
+        self.write_bytes(arr)
+        # do 8 byte padding if needed
+        if tag['byte_count']%8 != 0:
+            pad = (1+tag['byte_count']//8)*8 - tag['byte_count']
+            self.write_bytes(N.zeros((pad,),dtype='u1'))
 
     def write_header(self, mclass,
                      is_global=False,
@@ -561,8 +606,11 @@
         af['flags_class'] = mclass | flags << 8
         af['nzmax'] = nzmax
         self.write_dtype(af)
+        # write array shape
+        self.arr=N.atleast_2d(self.arr)
         self.write_element(N.array(self.arr.shape, dtype='i4'))
-        self.write_element(self.name)
+        # write name
+        self.write_element(N.ndarray(shape=len(self.name), dtype='S1', buffer=self.name))
 
     def update_matrix_tag(self):
         curr_pos = self.file_stream.tell()
@@ -578,34 +626,43 @@
 class Mat5NumericWriter(Mat5MatrixWriter):
 
     def write(self):
-        # identify matlab type for array
-        # make at least 2d
-        # maybe downcast array to smaller matlab type
-        # write real
-        # write imaginary
-        # put padded length in miMATRIX tag
-        pass
-    
+        imagf = self.arr.dtype.kind == 'c'
+        try:
+            mclass = np_to_mxtypes[self.arr.dtype.str[1:]]
+        except KeyError:
+            if imagf:
+                self.arr = self.arr.astype('c128')
+            else:
+                self.arr = self.arr.astype('f8')
+            mclass = mxDOUBLE_CLASS
+        self.write_header(mclass=mclass,is_complex=imagf)
+        if imagf:
+            self.write_element(self.arr.real)
+            self.write_element(self.arr.imag)
+        else:
+            self.write_element(self.arr)
+        self.update_matrix_tag()
 
 class Mat5CharWriter(Mat5MatrixWriter):
-
+    codec='ascii'
     def write(self):
         self.arr_to_chars()
-        self.arr_to_2d()
-        dims = self.arr.shape
-        self.write_header(P=miUINT8,
-                          T=mxCHAR_CLASS)
+        self.write_header(mclass=mxCHAR_CLASS)
         if self.arr.dtype.kind == 'U':
-            # Recode unicode to ascii
-            n_chars = N.product(dims)
+            # Recode unicode using self.codec
+            n_chars = N.product(self.arr.shape)
             st_arr = N.ndarray(shape=(),
                              dtype=self.arr_dtype_number(n_chars),
                              buffer=self.arr)
-            st = st_arr.item().encode('ascii')
-            self.arr = N.ndarray(shape=dims, dtype='S1', buffer=st)
-        self.write_bytes(self.arr)
+            st = st_arr.item().encode(self.codec)
+            self.arr = N.ndarray(shape=(len(st)), dtype='u1', buffer=st)
+        self.write_element(self.arr,mdtype=miUTF8)
+        self.update_matrix_tag()
 
+class Mat5UniCharWriter(Mat5CharWriter):
+    codec='UTF8'
 
+
 class Mat5SparseWriter(Mat5MatrixWriter):
 
     def write(self):
@@ -650,7 +707,7 @@
                 return Mat5SparseWriter(self.stream, arr, name, is_global)
         arr = N.array(arr)
         if arr.dtype.hasobject:
-            types, arr_type = classify_mobjects(arr)
+            types, arr_type = self.classify_mobjects(arr)
             if arr_type == 'c':
                 return Mat5CellWriter(self.stream, arr, name, is_global, types)
             elif arr_type == 's':
@@ -658,10 +715,10 @@
             elif arr_type == 'o':
                 return Mat5ObjectWriter(self.stream, arr, name, is_global)
         if arr.dtype.kind in ('U', 'S'):
-            if self.unicode_strings:
+           if self.unicode_strings:
                 return Mat5UniCharWriter(self.stream, arr, name, is_global)
-            else:
-                return Mat5IntCharWriter(self.stream, arr, name, is_global)            
+           else:
+                return Mat5CharWriter(self.stream, arr, name, is_global)            
         else:
             return Mat5NumericWriter(self.stream, arr, name, is_global)
                     
@@ -678,12 +735,12 @@
                         s  - struct array
                         o  - object array
         '''
-        N = objarr.size
-        types = N.empty((N,), dtype='S1')
+        n = objarr.size
+        types = N.empty((n,), dtype='S1')
         types[:] = 'i'
         type_set = set()
         flato = objarr.flat
-        for i in range(N):
+        for i in range(n):
             obj = flato[i]
             if isinstance(obj, N.ndarray):
                 types[i] = 'a'
@@ -719,8 +776,16 @@
         else:
             self.global_vars = []
         self.writer_getter = Mat5WriterGetter(
-            StringIO,
+            StringIO(),
             unicode_strings)
+        # write header
+        import os, time
+        hdr =  N.zeros((), mdtypes_template['file_header'])
+        hdr['description']='MATLAB 5.0 MAT-file Platform: %s, Created on: %s' % (
+                            os.name,time.asctime())
+        hdr['version']= 0x0100
+        hdr['endian_test']=N.ndarray(shape=(),dtype='S2',buffer=N.uint16(0x4d49))
+        file_stream.write(hdr.tostring())
 
     def get_unicode_strings(self):
         return self.write_getter.unicode_strings
@@ -740,11 +805,12 @@
                 name,
                 is_global,
                 ).write()
+            stream = self.writer_getter.stream
             if self.do_compression:
-                str = zlib.compress(stream.getvalue())
+                str = zlib.compress(stream.getvalue(stream.tell()))
                 tag = N.empty((), mdtypes_template['tag_full'])
                 tag['mdtype'] = miCOMPRESSED
                 tag['byte_count'] = len(str)
                 self.file_stream.write(tag.tostring() + str)
             else:
-                self.file_stream.write(stream.getvalue())
+                self.file_stream.write(stream.getvalue(stream.tell()))

Modified: trunk/scipy/io/tests/test_mio.py
===================================================================
--- trunk/scipy/io/tests/test_mio.py	2007-10-12 17:16:54 UTC (rev 3434)
+++ trunk/scipy/io/tests/test_mio.py	2007-10-13 18:33:22 UTC (rev 3435)
@@ -13,7 +13,6 @@
 set_package_path()
 from scipy.io.mio import loadmat, savemat
 from scipy.io.mio5 import mat_obj, mat_struct
-from scipy.io.mio4 import MatFile4Writer
 restore_path()
 
 try:  # Python 2.3 support
@@ -87,10 +86,10 @@
         return cc
 
     # Add the round trip tests dynamically, with given parameters
-    def _make_rt_check_case(name, expected):
+    def _make_rt_check_case(name, expected, format):
         def cc(self):
             mat_stream = StringIO()
-            savemat(mat_stream, expected)
+            savemat(mat_stream, expected, format)
             mat_stream.seek(0)
             self._check_case(name, [mat_stream], expected)
         cc.__doc__ = "check loadmat case %s" % name
@@ -226,10 +225,12 @@
         assert files, "No files for test %s using filter %s" % (name, filt)
         exec 'check_%s = _make_check_case(name, files, expected)' % name
     # round trip tests
-    for case in case_table4:
+    for case in case_table4 + case_table5:
         name = case['name'] + '_round_trip'
         expected = case['expected']
-        exec 'check_%s = _make_rt_check_case(name, expected)' % name
+        format = case in case_table4 and '4' or '5'
+        exec 'check_%s = _make_rt_check_case(name, expected, format)' \
+             % name
 
 
 if __name__ == "__main__":




More information about the Scipy-svn mailing list