[Scipy-svn] r2212 - in trunk/Lib/io: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Sep 22 08:20:34 EDT 2006


Author: matthew.brett at gmail.com
Date: 2006-09-22 07:20:28 -0500 (Fri, 22 Sep 2006)
New Revision: 2212

Modified:
   trunk/Lib/io/mio4.py
   trunk/Lib/io/mio5.py
   trunk/Lib/io/miobase.py
   trunk/Lib/io/tests/gen_unittests.m
   trunk/Lib/io/tests/test_mio.py
Log:
More minor optimizations; bug fix for empty matrices; unit test for empty matrices

Modified: trunk/Lib/io/mio4.py
===================================================================
--- trunk/Lib/io/mio4.py	2006-09-21 22:26:48 UTC (rev 2211)
+++ trunk/Lib/io/mio4.py	2006-09-22 12:20:28 UTC (rev 2212)
@@ -52,50 +52,16 @@
     4: 'Cray', #!!
     }
 
-class Mat4Header(object):
-    ''' Place holder for Mat4 header
 
-        Defines:
-        next_position - start position of next matrix
-        name
-        dims - shape of matrix as stored (see sparse reader)
-        dtype - numpy dtype of matrix
-        mclass - matlab (TM) code for class of matrix
-        is_char    - True if these are char data
-        is_numeric - True if these are numeric data
-        is_complex - True if data are complex
-        original_dtype - data type in matlab (TM) workspace
-    '''
-    def __init__(self):
-        self.next_position = None
-        self.name = ''
-        self.dims = ()
-        self.dtype = None
-        self.mclass = None
-        self.is_char = None
-        self.is_numeric = None
-        self.is_complex = None
-        self.original_dtype = None
-        
-
 class Mat4ArrayReader(MatArrayReader):
     ''' Class for reading Mat4 arrays
     '''
     
-    def __init__(self, *args, **kwargs):
-        super(Mat4ArrayReader,self).__init__(*args, **kwargs)
-        self._getter_classes = {
-            mxFULL_CLASS: Mat4FullGetter,
-            mxCHAR_CLASS: Mat4CharGetter,
-            mxSPARSE_CLASS: Mat4SparseGetter,
-            }
-        
-    def read_header(self):
-        ''' Read and return Mat4 matrix header
-        '''
-        header = Mat4Header()
+    def matrix_getter_factory(self):
+        ''' Read header, return matrix getter '''
         data = self.read_dtype(self.dtypes['header'])
-        header.name = self.read_ztstring(data['namlen'])
+        header = {}
+        header['name'] = self.read_ztstring(data['namlen'])
         if data['mopt'] < 0 or  data['mopt'] > 5000:
             ValueError, 'Mat 4 mopt wrong format, byteswapping problem?'
         M,rest = divmod(data['mopt'], 1000)
@@ -104,21 +70,24 @@
         T = rest
         if O != 0:
             raise ValueError, 'O in MOPT integer should be 0, wrong format?'
-        header.dtype = self.dtypes[P]
-        header.mclass = T
-        header.dims = (data['mrows'], data['ncols'])
-        header.is_complex = data['imagf'] == 1
-        remaining_bytes = header.dtype.itemsize * product(header.dims)
-        if header.is_complex and not header.mclass == mxSPARSE_CLASS:
+        header['dtype'] = self.dtypes[P]
+        header['mclass'] = T
+        header['dims'] = (data['mrows'], data['ncols'])
+        header['is_complex'] = data['imagf'] == 1
+        remaining_bytes = header['dtype'].itemsize * product(header['dims'])
+        if header['is_complex'] and not header['mclass'] == mxSPARSE_CLASS:
             remaining_bytes *= 2
-        header.next_position = self.mat_stream.tell() + remaining_bytes
-        return header
+        header['next_position'] = self.mat_stream.tell() + remaining_bytes
+        if T == mxFULL_CLASS:
+            return Mat4FullGetter(self, header)
+        elif T == mxCHAR_CLASS:
+            return Mat4CharGetter(self, header)
+        elif T == mxSPARSE_CLASS:
+            return Mat4SparseGetter(self, header)
+        else:
+            raise TypeError, 'No reader for class code %s' % T
 
-    def matrix_getter_factory(self):
-        header = self.read_header()
-        return self._getter_classes[header.mclass](self, header)
 
-
 class Mat4MatrixGetter(MatMatrixGetter):
 
     # Mat4 variables never global or logical
@@ -131,11 +100,12 @@
         (buffer is usually read only)
         a_dtype is assumed to be correct endianness
         '''
-        dt = self.header.dtype
+        dt = self.header['dtype']
+        dims = self.header['dims']
         num_bytes = dt.itemsize
-        for d in self.dims:
+        for d in dims:
             num_bytes *= d
-        arr = ndarray(shape=self.dims,
+        arr = ndarray(shape=dims,
                       dtype=dt,
                       buffer=self.mat_stream.read(num_bytes),
                       order='F')
@@ -145,26 +115,28 @@
 
 
 class Mat4FullGetter(Mat4MatrixGetter):
+    def __init__(self, array_reader, header):
+        super(Mat4FullGetter, self).__init__(array_reader, header)
+        if header['is_complex']:
+            self.mat_dtype = dtype(complex128)
+        else:
+            self.mat_dtype = dtype(float64)
+        
     def get_raw_array(self):
-        self.header.is_numeric = True
-        if self.header.is_complex:
-            self.header.original_dtype = dtype(complex128)
+        if self.header['is_complex']:
             # avoid array copy to save memory
             res = self.read_array(copy=False)
             res_j = self.read_array(copy=False)
             return res + (res_j * 1j)
-        else:
-            self.header.original_dtype = dtype(float64)
-            return self.read_array()
+        return self.read_array()
 
 
 class Mat4CharGetter(Mat4MatrixGetter):
     def get_raw_array(self):
-        self.header.is_char = True
         arr = self.read_array().astype(uint8)
         # ascii to unicode
         S = arr.tostring().decode('ascii')
-        return ndarray(shape=self.dims,
+        return ndarray(shape=self.header['dims'],
                        dtype=dtype('U1'),
                        buffer = array(S)).copy()
 
@@ -187,14 +159,12 @@
     is only detectable because there are 4 storage columns
     '''
     def get_raw_array(self):
-        self.header.original_dtype = dtype(float64)
         res = self.read_array()
         tmp = res[:-1,:]
         dims = res[-1,0:2]
         ij = transpose(tmp[:,0:2]) - 1 # for 1-based indexing
         vals = tmp[:,2]
         if res.shape[1] == 4:
-            self.header.is_complex = True
             vals = vals + res[:-1,3] * 1j
         if have_sparse:
             return scipy.sparse.csc_matrix((vals,ij), dims)

Modified: trunk/Lib/io/mio5.py
===================================================================
--- trunk/Lib/io/mio5.py	2006-09-21 22:26:48 UTC (rev 2211)
+++ trunk/Lib/io/mio5.py	2006-09-22 12:20:28 UTC (rev 2212)
@@ -132,35 +132,6 @@
     ''' Placeholder for holding read data from objects '''
     pass
 
-class Mat5Header(object):
-    ''' Placeholder for Mat5 header
-
-    Defines:
-    next_position - start position of next matrix
-    name
-    dtype - numpy dtype of matrix
-    mclass - matlab (TM) code for class of matrix
-    dims - shape of matrix as stored (see sparse reader)
-    is_complex - True if data are complex
-    is_char    - True if these are char data
-    is_global  - is a global variable in matlab (TM) workspace
-    is_numeric - is basic numeric matrix
-    original_dtype - data type when saved from matlab (TM)
-    '''
-    def __init__(self):
-        self.next_position = None
-        self.is_empty = False
-        self.is_complex = False
-        self.is_global = False
-        self.is_logical = False
-        self.mclass = 0
-        self.is_numeric = None
-        self.original_dtype = None
-        self.is_char = None
-        self.dims = ()
-        self.name = ''
-
-
 class Mat5ArrayReader(MatArrayReader):
     ''' Class to get Mat5 arrays
 
@@ -210,26 +181,10 @@
             if copy:
                 el = el.copy()
         mod8 = byte_count % 8
-        skip = mod8 and 8 - mod8
-        if skip:
-            self.mat_stream.seek(skip, 1)
+        if mod8:
+            self.mat_stream.seek(8 - mod8, 1)
         return el
 
-    def read_header(self):
-        ''' Read header from Mat5 matrix
-        '''
-        header = Mat5Header()
-        af = self.read_dtype(self.dtypes['array_flags'])
-        flags_class = af['flags_class']
-        header.mclass = flags_class & 0xFF
-        header.is_logical = flags_class >> 9 & 1
-        header.is_global = flags_class >> 10 & 1
-        header.is_complex = flags_class >> 11 & 1
-        header.nzmax = af['nzmax']
-        header.dims = self.read_element()
-        header.name = self.read_element().tostring()
-        return header
-    
     def matrix_getter_factory(self):
         ''' Returns reader for next matrix '''
         tag = self.read_dtype(self.dtypes['tag_full'])
@@ -243,13 +198,22 @@
         return self.getter_from_bytes(byte_count)
 
     def getter_from_bytes(self, byte_count):
+        ''' Return matrix getter for current stream position '''
         # Apparently an empty miMATRIX can contain no bytes
         if not byte_count:
-            return Mat5EmptyMatrixGetter(self, header)
-        next_pos = self.mat_stream.tell() + byte_count
-        header = self.read_header()
-        header.next_position = next_pos
-        mc = header.mclass
+            return Mat5EmptyMatrixGetter(self)
+        af = self.read_dtype(self.dtypes['array_flags'])
+        header = {}
+        flags_class = af['flags_class']
+        header['next_position'] = self.mat_stream.tell() + byte_count
+        mc = flags_class & 0xFF
+        header['mclass'] = mc
+        header['is_logical'] = flags_class >> 9 & 1
+        header['is_global'] = flags_class >> 10 & 1
+        header['is_complex'] = flags_class >> 11 & 1
+        header['nzmax'] = af['nzmax']
+        header['dims'] = self.read_element()
+        header['name'] = self.read_element().tostring()
         if mc in mx_numbers:
             return Mat5NumericMatrixGetter(self, header)
         if mc == mxSPARSE_CLASS:
@@ -292,7 +256,7 @@
         in the main stream, not the compressed stream.
         '''
         getter = super(Mat5ZArrayReader, self).getter_from_bytes(byte_count)
-        getter.next_position = self._next_position
+        getter.header['next_position'] = self._next_position
         return getter
     
 
@@ -303,34 +267,49 @@
     '''
     
     def __init__(self, array_reader, header):
-        ''' Accepts @array_reader and @header '''
         super(Mat5MatrixGetter, self).__init__(array_reader, header)
         self.class_dtypes = array_reader.class_dtypes
         self.codecs = array_reader.codecs
-        self.is_global = header.is_global
+        self.is_global = header['is_global']
+        self.mat_dtype = None
 
     def read_element(self, *args, **kwargs):
         return self.array_reader.read_element(*args, **kwargs)
     
 
 class Mat5EmptyMatrixGetter(Mat5MatrixGetter):
-    ''' Dummy class to return empty array for empty matrix '''
+    ''' Dummy class to return empty array for empty matrix
+    '''
+    def __init__(self, array_reader):
+        self.array_reader = array_reader
+        self.mat_stream = array_reader.mat_stream
+        self.data_position = self.mat_stream.tell()
+        self.header = {}
+        self.is_global = False
+        self.mat_dtype = 'f8'
+    
     def get_raw_array(self):
         return array([[]])
 
 
 class Mat5NumericMatrixGetter(Mat5MatrixGetter):
+
+    def __init__(self, array_reader, header):
+        super(Mat5NumericMatrixGetter, self).__init__(array_reader, header)
+        if header['is_logical']:
+            self.mat_dtype = dtype('bool')
+        else:
+            self.mat_dtype = self.class_dtypes[header['mclass']]
+
     def get_raw_array(self):
-        self.header.is_numeric = True
-        self.header.original_dtype = self.class_dtypes[self.header.mclass]
-        if self.header.is_complex:
+        if self.header['is_complex']:
             # avoid array copy to save memory
             res = self.read_element(copy=False)
             res_j = self.read_element(copy=False)
             res = res + (res_j * 1j)
         else:
             res = self.read_element()
-        return ndarray(shape=self.dims,
+        return ndarray(shape=self.header['dims'],
                        dtype=res.dtype,
                        buffer=res,
                        order='F')
@@ -340,7 +319,7 @@
     def get_raw_array(self):
         rowind  = self.read_element()
         colind = self.read_element()
-        if self.header.is_complex:
+        if self.header['is_complex']:
             # avoid array copy to save memory
             res = self.read_element(copy=False)
             res_j = self.read_element(copy=False)
@@ -369,7 +348,7 @@
         ij = vstack((rowind[:len(res)], cols))
         if have_sparse:
             result = scipy.sparse.csc_matrix((res,ij),
-                                             self.dims)
+                                             self.header['dims'])
         else:
             result = (dims, ij, res)
         return result
@@ -377,7 +356,6 @@
 
 class Mat5CharMatrixGetter(Mat5MatrixGetter):
     def get_raw_array(self):
-        self.header.is_char = True
         res = self.read_element()
         # Convert non-string types to unicode
         if isinstance(res, ndarray):
@@ -390,7 +368,7 @@
             else:
                 raise TypeError, 'Did not expect type %s' % res.dtype
             res = res.tostring().decode(codec)
-        return ndarray(shape=self.dims,
+        return ndarray(shape=self.header['dims'],
                        dtype=dtype('U1'),
                        buffer=array(res),
                        order='F').copy()
@@ -399,8 +377,8 @@
 class Mat5CellMatrixGetter(Mat5MatrixGetter):
     def get_raw_array(self):
         # Account for fortran indexing of cells
-        tupdims = tuple(self.dims[::-1])
-        length = product(self.dims)
+        tupdims = tuple(self.header['dims'][::-1])
+        length = product(tupdims)
         result = empty(length, dtype=object)
         for i in range(length):
             result[i] = self.get_item()

Modified: trunk/Lib/io/miobase.py
===================================================================
--- trunk/Lib/io/miobase.py	2006-09-21 22:26:48 UTC (rev 2211)
+++ trunk/Lib/io/miobase.py	2006-09-22 12:20:28 UTC (rev 2212)
@@ -15,6 +15,13 @@
     have_sparse = 0
 
 
+def small_product(arr):
+    ''' Faster than product for small arrays '''
+    res = 1
+    for e in arr:
+        res *= e
+    return res
+
 class ByteOrder(object):
     ''' Namespace for byte ordering '''
     little_endian = sys.byteorder == 'little'
@@ -170,17 +177,26 @@
         occur as submatrices - in cell arrays, structs and objects -
         so we will not see these in the main variable getting routine
         here.
+
+        The read array is the first argument.
+        The getter, passed as second argument to the function, must
+        define properties, iff matlab_compatible option is True:
+        
+        mat_dtype    - data type when loaded into matlab (tm)
+                       (None for no conversion)
+
+        func returns the processed array
         '''
         
-        def func(arr, header):
-            if header.is_char and self.chars_as_strings:
+        def func(arr, getter):
+            if arr.dtype.kind == 'U' and self.chars_as_strings:
                 # Convert char array to string or array of strings
                 dims = arr.shape
                 if len(dims) >= 2: # return array of strings
                     dtt = self.order_code + 'U'
                     n_dims = dims[:-1]
                     str_arr = reshape(arr,
-                                    (product(n_dims),
+                                    (small_product(n_dims),
                                      dims[-1]))
                     arr = empty(n_dims, dtype=object)
                     for i in range(0, n_dims[-1]):
@@ -190,22 +206,20 @@
             if self.matlab_compatible:
                 # Apply options to replicate matlab's (TM)
                 # load into workspace
-                if header.is_logical:
-                    arr = arr.astype(bool)
-                elif header.is_numeric:
-                    # Cast as original matlab (TM) type
-                    if header.original_dtype:
-                        arr = arr.astype(header.original_dtype)
+                if getter.mat_dtype:
+                    arr = arr.astype(getter.mat_dtype)
             if self.squeeze_me:
                 arr = squeeze(arr)
-                if not arr.shape: # 0d coverted to scalar
+                if not arr.size:
+                    arr = array([])
+                elif not arr.shape: # 0d coverted to scalar
                     arr = arr.item()
             return arr
         return func
 
     def chars_to_str(self, str_arr):
         ''' Convert string array to string '''
-        dt = dtype('U' + str(product(str_arr.shape)))
+        dt = dtype('U' + str(small_product(str_arr.shape)))
         return ndarray(shape=(),
                        dtype = dt,
                        buffer = str_arr.copy()).item()
@@ -254,7 +268,7 @@
 
     Accepts
     @array_reader - array reading object (see below)
-    @header       - header for matrix being read
+    @header       - header dictionary for matrix being read
     """
     
     def __init__(self, array_reader, header):
@@ -262,9 +276,7 @@
         self.array_reader = array_reader
         self.dtypes = array_reader.dtypes
         self.header = header
-        self.name = header.name
-        self.next_position = header.next_position
-        self.dims = header.dims
+        self.name = header['name']
         self.data_position = self.mat_stream.tell()
         
     def get_array(self):
@@ -272,13 +284,13 @@
         if not self.mat_stream.tell() == self.data_position:
             self.mat_stream.seek(self.data_position)
         arr = self.get_raw_array()
-        return self.array_reader.processor_func(arr, self.header)
+        return self.array_reader.processor_func(arr, self)
 
     def get_raw_array(self):
         assert False, 'Not implemented'
 
     def to_next(self):
-        self.mat_stream.seek(self.next_position)
+        self.mat_stream.seek(self.header['next_position'])
 
 
 class MatArrayReader(MatStreamAgent):

Modified: trunk/Lib/io/tests/gen_unittests.m
===================================================================
--- trunk/Lib/io/tests/gen_unittests.m	2006-09-21 22:26:48 UTC (rev 2211)
+++ trunk/Lib/io/tests/gen_unittests.m	2006-09-22 12:20:28 UTC (rev 2212)
@@ -47,6 +47,7 @@
 % Two variables in same file
 save([FILEPREFIX 'testmulti' FILESUFFIX], 'a', 'theta')
 
+
 % struct
 save_test('teststruct', ...
 	  struct('stringfield','Rats live on no evil star.',...
@@ -58,6 +59,9 @@
 	  {['This cell contains this string and 3 arrays of increasing' ...
 	    ' length'], 1., 1.:2., 1.:3.});
 
+% Empty cells in two cell matrices
+save_test('testemptycell', {1, 2, [], [], 3});
+
 % 3D matrix
 save_test('test3dmatrix', reshape(1:24,[2 3 4]))
 

Modified: trunk/Lib/io/tests/test_mio.py
===================================================================
--- trunk/Lib/io/tests/test_mio.py	2006-09-21 22:26:48 UTC (rev 2211)
+++ trunk/Lib/io/tests/test_mio.py	2006-09-22 12:20:28 UTC (rev 2212)
@@ -151,6 +151,12 @@
                             dtype=object)}
          }]
     case_table5.append(
+        {'name': 'emptycell',
+         'expected': {'testemptycell':
+                      array([array(1), array(2), array([]),
+                             array([]), array(3)], dtype=object)}
+         })
+    case_table5.append(
         {'name': 'stringarray',
          'expected': {'teststringarray': array(
         [u'one  ', u'two  ', u'three'], dtype=object)},




More information about the Scipy-svn mailing list