[Scipy-svn] r3675 - branches/io_new

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Dec 15 23:35:46 EST 2007


Author: brian.hawthorne
Date: 2007-12-15 22:35:42 -0600 (Sat, 15 Dec 2007)
New Revision: 3675

Modified:
   branches/io_new/mmio.py
Log:
cleanup, factoring

Modified: branches/io_new/mmio.py
===================================================================
--- branches/io_new/mmio.py	2007-12-16 04:35:23 UTC (rev 3674)
+++ branches/io_new/mmio.py	2007-12-16 04:35:42 UTC (rev 3675)
@@ -14,8 +14,10 @@
                   empty, concatenate, ones, ascontiguousarray
 from itertools import izip
 
-__all__ = ['mminfo','mmread','mmwrite']
+__all__ = ['mminfo','mmread','mmwrite', 'MMFile']
 
+
+#-------------------------------------------------------------------------------
 def mminfo(source):
     """ Queries the contents of the Matrix Market file 'filename' to
     extract size and storage information.
@@ -29,51 +31,13 @@
       rows,cols  - number of matrix rows and columns
       entries    - number of non-zero entries of a sparse matrix
                    or rows*cols for a dense matrix
-      rep        - 'coordinate' | 'array'
+      format     - 'coordinate' | 'array'
       field      - 'real' | 'complex' | 'pattern' | 'integer'
       symm       - 'general' | 'symmetric' | 'skew-symmetric' | 'hermitian'
     """
-    close_it = 0
-    if type(source) is type(''):
-        if not os.path.isfile(source):
-            if source[-4:] != '.mtx':
-                source = source + '.mtx'
-        source = open(source,'r')
-        close_it = 1
-    line = source.readline().split()
-    if not line[0].startswith('%%MatrixMarket'):
-        raise ValueError,'source is not in Matrix Market format'
+    return MMFile.info(source)
 
-    assert len(line)==5,`line`
-
-    assert line[1].strip().lower()=='matrix',`line`
-
-    rep = line[2].strip().lower()
-    if rep=='dense': rep='array'
-    elif rep=='sparse': rep='coordinate'
-
-    field = line[3].strip().lower()
-
-    symm = line[4].strip().lower()
-
-    while line:
-        line = source.readline()
-        if line.startswith('%'):
-            continue
-        line = line.split()
-        if rep=='array':
-            assert len(line)==2,`line`
-            rows,cols = map(eval,line)
-            entries = rows*cols
-        else:
-            assert len(line)==3,`line`
-            rows,cols,entries = map(eval,line)
-        break
-
-    if close_it:
-        source.close()
-    return (rows,cols,entries,rep,field,symm)
-
+#-------------------------------------------------------------------------------
 def mmread(source):
     """ Reads the contents of a Matrix Market file 'filename' into a matrix.
 
@@ -86,313 +50,525 @@
 
       a         - sparse or full matrix
     """
-    close_it = 0
-    if type(source) is type(''):
-        if not os.path.isfile(source):
-            if os.path.isfile(source+'.mtx'):
-                source = source + '.mtx'
-            elif os.path.isfile(source+'.mtx.gz'):
-                source = source + '.mtx.gz'
-        if source[-3:] == '.gz':
-            import gzip
-            source = gzip.open(source)
-        else:
-            source = open(source,'r')
-        close_it = 1
+    return MMFile().read(source)
 
-    rows,cols,entries,rep,field,symm = mminfo(source)
+#-------------------------------------------------------------------------------
+def mmwrite(target, a, comment='', field=None, precision=None):
+    """ Writes the sparse or dense matrix A to a Matrix Market formatted file.
 
-    try:
-        from scipy.sparse import coo_matrix
-    except ImportError:
-        coo_matrix = None
+    Inputs:
 
-    if field=='integer':
-        dtype='i'
-    elif field=='real':
-        dtype='d'
-    elif field=='complex':
-        dtype='D'
-    elif field=='pattern':
-        dtype='d'
-    else:
-        raise ValueError,`field`
+      target    - Matrix Market filename (extension .mtx) or open file object
+      a         - sparse or full matrix
+      comment   - comments to be prepended to the Matrix Market file
+      field     - 'real' | 'complex' | 'pattern' | 'integer'
+      precision - Number of digits to display for real or complex values.
+    """
+    MMFile().write(target, a, comment, field, precision)
 
-    has_symmetry = symm in ['symmetric','skew-symmetric','hermitian']
-    is_complex = field=='complex'
-    is_skew = symm=='skew-symmetric'
-    is_herm = symm=='hermitian'
-    is_pattern = field=='pattern'
 
-    if rep == 'array':
-        a = zeros((rows,cols),dtype=dtype)
-        line = 1
-        i,j = 0,0
-        while line:
+################################################################################
+class MMFile (object):
+    __slots__ = (
+      '_rows',
+      '_cols',
+      '_entries',
+      '_format',
+      '_field',
+      '_symmetry')
+
+    @property
+    def rows(self): return self._rows
+    @property
+    def cols(self): return self._cols
+    @property
+    def entries(self): return self._entries
+    @property
+    def format(self): return self._format
+    @property
+    def field(self): return self._field
+    @property
+    def symmetry(self): return self._symmetry
+
+    @property
+    def has_symmetry(self):
+        return self._symmetry in (self.SYMMETRY_SYMMETRIC,
+          self.SYMMETRY_SKEW_SYMMETRIC, self.SYMMETRY_HERMITIAN)
+
+    # format values
+    FORMAT_COORDINATE = 'coordinate'
+    FORMAT_ARRAY = 'array'
+    FORMAT_VALUES = (FORMAT_COORDINATE, FORMAT_ARRAY)
+
+    @classmethod
+    def _validate_format(self, format):
+        if format not in self.FORMAT_VALUES:
+            raise ValueError,'unknown format type %s, must be one of %s'% \
+              (`format`, `self.FORMAT_VALUES`)
+
+    # field values
+    FIELD_INTEGER = 'integer'
+    FIELD_REAL = 'real'
+    FIELD_COMPLEX = 'complex'
+    FIELD_PATTERN = 'pattern'
+    FIELD_VALUES = (FIELD_INTEGER, FIELD_REAL, FIELD_COMPLEX, FIELD_PATTERN)
+
+    @classmethod
+    def _validate_field(self, field):
+        if field not in self.FIELD_VALUES:
+            raise ValueError,'unknown field type %s, must be one of %s'% \
+              (`field`, `self.FIELD_VALUES`)
+
+    # symmetry values
+    SYMMETRY_GENERAL = 'general'
+    SYMMETRY_SYMMETRIC = 'symmetric'
+    SYMMETRY_SKEW_SYMMETRIC = 'skew-symmetric'
+    SYMMETRY_HERMITIAN = 'hermitian'
+    SYMMETRY_VALUES = (
+      SYMMETRY_GENERAL, SYMMETRY_SYMMETRIC, SYMMETRY_SKEW_SYMMETRIC,
+      SYMMETRY_HERMITIAN)
+
+    @classmethod
+    def _validate_symmetry(self, symmetry):
+        if symmetry not in self.SYMMETRY_VALUES:
+            raise ValueError,'unknown symmetry type %s, must be one of %s'% \
+              (`symmetry`, `self.SYMMETRY_VALUES`)
+
+    DTYPES_BY_FIELD = {
+      FIELD_INTEGER: 'i',
+      FIELD_REAL: 'd',
+      FIELD_COMPLEX:'D',
+      FIELD_PATTERN:'d'}
+
+    #---------------------------------------------------------------------------
+    @staticmethod
+    def reader(): pass
+
+    #---------------------------------------------------------------------------
+    @staticmethod
+    def writer(): pass
+
+    #---------------------------------------------------------------------------
+    @classmethod
+    def info(self, source):
+        source, close_it = self._open(source)
+
+        try:
+
+            # read and validate header line
             line = source.readline()
-            if not line or line.startswith('%'):
-                continue
-            if is_complex:
-                aij = complex(*map(float,line.split()))
+            mmid, matrix, format, field, symmetry  = \
+              [part.strip().lower() for part in line.split()]
+            if not mmid.startswith('%%matrixmarket'):
+                raise ValueError,'source is not in Matrix Market format'
+
+            assert matrix == 'matrix',`line`
+
+            # ??? Is this necessary?  I don't see 'dense' or 'sparse' in the spec
+            # http://math.nist.gov/MatrixMarket/formats.html
+            if format == 'dense': format = self.FORMAT_ARRAY
+            elif format == 'sparse': format = self.FORMAT_COORDINATE
+
+            # skip comments
+            while line.startswith('%'): line = source.readline()
+
+            line = line.split()
+            if format == self.FORMAT_ARRAY:
+                assert len(line)==2,`line`
+                rows,cols = map(float, line)
+                entries = rows*cols
             else:
-                aij = float(line)
-            a[i,j] = aij
-            if has_symmetry and i!=j:
-                if is_skew:
-                    a[j,i] = -aij
-                elif is_herm:
-                    a[j,i] = conj(aij)
+                assert len(line)==3,`line`
+                rows, cols, entries = map(float, line)
+
+            return (rows, cols, entries, format, field, symmetry)
+
+        finally:
+            if close_it: source.close()
+
+    #---------------------------------------------------------------------------
+    @staticmethod
+    def _open(filespec, mode='r'):
+        """
+        Return (stream, close_it) for filespec.  A file name is opened (reads
+        try the .mtx and .mtx.gz extensions; writes append .mtx if missing)
+        with close_it True; anything else is returned as-is, close_it False.
+        """
+        close_it = False
+        if type(filespec) is type(''):
+            close_it = True
+
+            # open for reading
+            if mode[0] == 'r':
+
+                # determine filename plus extension
+                if not os.path.isfile(filespec):
+                    if os.path.isfile(filespec+'.mtx'):
+                        filespec = filespec + '.mtx'
+                    elif os.path.isfile(filespec+'.mtx.gz'):
+                        filespec = filespec + '.mtx.gz'
+                # open filename
+                if filespec[-3:] == '.gz':
+                    import gzip
+                    stream = gzip.open(filespec, mode)
                 else:
-                    a[j,i] = aij
-            if i<rows-1:
-                i = i + 1
+                    stream = open(filespec, mode)
+     
+            # open for writing
             else:
-                j = j + 1
-                if not has_symmetry:
-                    i = 0
-                else:
-                    i = j
-        assert i in [0,j] and j==cols,`i,j,rows,cols`
+                if filespec[-4:] != '.mtx':
+                    filespec = filespec + '.mtx'
+                stream = open(filespec, mode)
+        else:
+            stream = filespec
 
-    elif rep=='coordinate' and coo_matrix is None:
-        # Read sparse matrix to dense when coo_matrix is not available.
-        a = zeros((rows,cols), dtype=dtype)
-        line = 1
-        k = 0
-        while line:
-            line = source.readline()
-            if not line or line.startswith('%'):
-                continue
-            l = line.split()
-            i,j = map(int,l[:2])
-            i,j = i-1,j-1
-            if is_complex:
-                aij = complex(*map(float,l[2:]))
-            else:
-                aij = float(l[2])
-            a[i,j] = aij
-            if has_symmetry and i!=j:
-                if is_skew:
-                    a[j,i] = -aij
-                elif is_herm:
-                    a[j,i] = conj(aij)
-                else:
-                    a[j,i] = aij
-            k = k + 1
-        assert k==entries,`k,entries`
+        return stream, close_it
 
-    elif rep=='coordinate':
-        from numpy import fromfile,fromstring
+    #---------------------------------------------------------------------------
+    @staticmethod
+    def _get_symmetry(a):
+        m,n = a.shape
+        if m!=n:
+            return MMFile.SYMMETRY_GENERAL
+        issymm = 1
+        isskew = 1
+        isherm = a.dtype.char in 'FD'
+        for j in range(n):
+            for i in range(j+1,n):
+                aij,aji = a[i][j],a[j][i]
+                if issymm and aij != aji:
+                    issymm = 0
+                if isskew and aij != -aji:
+                    isskew = 0
+                if isherm and aij != conj(aji):
+                    isherm = 0
+                if not (issymm or isskew or isherm):
+                    break
+        if issymm: return MMFile.SYMMETRY_SYMMETRIC
+        if isskew: return MMFile.SYMMETRY_SKEW_SYMMETRIC
+        if isherm: return MMFile.SYMMETRY_HERMITIAN
+        return MMFile.SYMMETRY_GENERAL
+
+    #---------------------------------------------------------------------------
+    @staticmethod
+    def _field_template(field, precision):
+        return {
+          MMFile.FIELD_REAL: '%%.%ie\n' % precision, 
+          MMFile.FIELD_INTEGER: '%i\n',
+          MMFile.FIELD_COMPLEX: '%%.%ie %%.%ie\n' % (precision,precision)
+        }.get(field, None)
+
+    #---------------------------------------------------------------------------
+    def __init__(self, **kwargs): self._init_attrs(**kwargs)
+
+    #---------------------------------------------------------------------------
+    def read(self, source):
+        stream, close_it = self._open(source)
+
         try:
-            # fromfile works for normal files
-            flat_data = fromfile(source,sep=' ')
-        except:
-            # fallback - fromfile fails for some file-like objects
-            flat_data = fromstring(source.read(),sep=' ')
-            
-            # TODO use iterator (e.g. xreadlines) to avoid reading
-            # the whole file into memory
-        
-        if is_pattern:
-            flat_data = flat_data.reshape(-1,2)
-            I = ascontiguousarray(flat_data[:,0], dtype='intc')
-            J = ascontiguousarray(flat_data[:,1], dtype='intc')
-            V = ones(len(I))
-        elif is_complex:
-            flat_data = flat_data.reshape(-1,4)
-            I = ascontiguousarray(flat_data[:,0], dtype='intc')
-            J = ascontiguousarray(flat_data[:,1], dtype='intc')
-            V = ascontiguousarray(flat_data[:,2], dtype='complex')
-            V.imag = flat_data[:,3]
-        else:
-            flat_data = flat_data.reshape(-1,3)
-            I = ascontiguousarray(flat_data[:,0], dtype='intc')
-            J = ascontiguousarray(flat_data[:,1], dtype='intc')
-            V = ascontiguousarray(flat_data[:,2], dtype='float')
+            self._parse_header(stream)
+            return self._parse_body(stream)
 
-        I -= 1 #adjust indices (base 1 -> base 0)
-        J -= 1
+        finally:
+            if close_it: stream.close()
 
-        if has_symmetry:
-            mask = (I != J)       #off diagonal mask
-            od_I = I[mask]
-            od_J = J[mask]
-            od_V = V[mask]
+    #---------------------------------------------------------------------------
+    def write(self, target, a, comment='', field=None, precision=None):
+        stream, close_it = self._open(target, 'w')
 
-            I = concatenate((I,od_J))
-            J = concatenate((J,od_I))
+        try:
+            self._write(stream, a, comment, field, precision)
 
-            if is_skew:
-                od_V *= -1
-            elif is_herm:
-                od_V = od_V.conjugate()
+        finally:
+            if close_it: stream.close()
+            else: stream.flush()
 
-            V = concatenate((V,od_V))
 
-        a = coo_matrix((V, (I, J)), dims=(rows, cols), dtype=dtype)
-    else:
-        raise NotImplementedError,`rep`
+    #---------------------------------------------------------------------------
+    def _init_attrs(self, **kwargs):
+        """
+        Initialize each attribute with the corresponding keyword arg value
+        or a default of None
+        """
+        attrs = self.__class__.__slots__
+        public_attrs = [attr[1:] for attr in attrs]
+        invalid_keys = set(kwargs.keys()) - set(public_attrs)
+       
+        if invalid_keys:
+            raise ValueError, \
+              'found %s invalid keyword arguments, please only use %s' % \
+              (`tuple(invalid_keys)`, `public_attrs`)
 
-    if close_it:
-        source.close()
-    return a
+        for attr in attrs: setattr(self, attr, kwargs.get(attr[1:], None))
 
-def mmwrite(target,a,comment='',field=None,precision=None):
-    """ Writes the sparse or dense matrix A to a Matrix Market formatted file.
+    #---------------------------------------------------------------------------
+    def _parse_header(self, stream):
+        rows, cols, entries, format, field, symmetry = \
+          self.__class__.info(stream)
+        self._init_attrs(rows=rows, cols=cols, entries=entries, format=format,
+          field=field, symmetry=symmetry)
 
-    Inputs:
+    #---------------------------------------------------------------------------
+    def _parse_body(self, stream):
+        rows, cols, entries, format, field, symm = \
+          (self.rows, self.cols, self.entries, self.format, self.field, self.symmetry)
 
-      target    - Matrix Market filename (extension .mtx) or open file object
-      a         - sparse or full matrix
-      comment   - comments to be prepended to the Matrix Market file
-      field     - 'real' | 'complex' | 'pattern' | 'integer'
-      precision - Number of digits to display for real or complex values.
-    """
-    close_it = 0
-    if type(target) is type(''):
-        if target[-4:] != '.mtx':
-            target = target + '.mtx'
-        target = open(target,'w')
-        close_it = 1
+        try:
+            from scipy.sparse import coo_matrix
+        except ImportError:
+            coo_matrix = None
 
-    if isinstance(a, list) or isinstance(a, ndarray) or isinstance(a, tuple) or hasattr(a,'__array__'):
-        rep = 'array'
-        a = asarray(a)
-        if len(a.shape) != 2:
-            raise ValueError, 'expected matrix'
-        rows,cols = a.shape
-        entries = rows*cols
-        typecode = a.dtype.char
-        if field is not None:
-            if field=='integer':
-                a = a.astype('i')
-            elif field=='real':
-                if typecode not in 'fd':
-                    a = a.astype('d')
-            elif field=='complex':
-                if typecode not in 'FD':
-                    a = a.astype('D')
-            elif field=='pattern':
-                pass
+        dtype = self.DTYPES_BY_FIELD.get(field, None)
+
+        has_symmetry = self.has_symmetry
+        is_complex = field == self.FIELD_COMPLEX
+        is_skew = symm == self.SYMMETRY_SKEW_SYMMETRIC
+        is_herm = symm == self.SYMMETRY_HERMITIAN
+        is_pattern = field == self.FIELD_PATTERN
+
+        if format == self.FORMAT_ARRAY:
+            a = zeros((rows,cols),dtype=dtype)
+            line = 1
+            i,j = 0,0
+            while line:
+                line = stream.readline()
+                if not line or line.startswith('%'):
+                    continue
+                if is_complex:
+                    aij = complex(*map(float,line.split()))
+                else:
+                    aij = float(line)
+                a[i,j] = aij
+                if has_symmetry and i!=j:
+                    if is_skew:
+                        a[j,i] = -aij
+                    elif is_herm:
+                        a[j,i] = conj(aij)
+                    else:
+                        a[j,i] = aij
+                if i<rows-1:
+                    i = i + 1
+                else:
+                    j = j + 1
+                    if not has_symmetry:
+                        i = 0
+                    else:
+                        i = j
+            assert i in [0,j] and j==cols,`i,j,rows,cols`
+
+        elif format == self.FORMAT_COORDINATE and coo_matrix is None:
+            # Read sparse matrix to dense when coo_matrix is not available.
+            a = zeros((rows,cols), dtype=dtype)
+            line = 1
+            k = 0
+            while line:
+                line = stream.readline()
+                if not line or line.startswith('%'):
+                    continue
+                l = line.split()
+                i,j = map(int,l[:2])
+                i,j = i-1,j-1
+                if is_complex:
+                    aij = complex(*map(float,l[2:]))
+                else:
+                    aij = float(l[2])
+                a[i,j] = aij
+                if has_symmetry and i!=j:
+                    if is_skew:
+                        a[j,i] = -aij
+                    elif is_herm:
+                        a[j,i] = conj(aij)
+                    else:
+                        a[j,i] = aij
+                k = k + 1
+            assert k==entries,`k,entries`
+
+        elif format == self.FORMAT_COORDINATE:
+            from numpy import fromfile,fromstring
+            try:
+                # fromfile works for normal files
+                flat_data = fromfile(stream,sep=' ')
+            except:
+                # fallback - fromfile fails for some file-like objects
+                flat_data = fromstring(stream.read(),sep=' ')
+                
+                # TODO use iterator (e.g. xreadlines) to avoid reading
+                # the whole file into memory
+            
+            if is_pattern:
+                flat_data = flat_data.reshape(-1,2)
+                I = ascontiguousarray(flat_data[:,0], dtype='intc')
+                J = ascontiguousarray(flat_data[:,1], dtype='intc')
+                V = ones(len(I))
+            elif is_complex:
+                flat_data = flat_data.reshape(-1,4)
+                I = ascontiguousarray(flat_data[:,0], dtype='intc')
+                J = ascontiguousarray(flat_data[:,1], dtype='intc')
+                V = ascontiguousarray(flat_data[:,2], dtype='complex')
+                V.imag = flat_data[:,3]
             else:
-                raise ValueError,'unknown field '+field
-        typecode = a.dtype.char
-    else:
-        rep = 'coordinate'
-        from scipy.sparse import spmatrix
-        if not isinstance(a,spmatrix):
-            raise ValueError,'unknown matrix type ' + `type(a)`
-        rows,cols = a.shape
-        entries = a.getnnz()
-        typecode = a.dtype.char
+                flat_data = flat_data.reshape(-1,3)
+                I = ascontiguousarray(flat_data[:,0], dtype='intc')
+                J = ascontiguousarray(flat_data[:,1], dtype='intc')
+                V = ascontiguousarray(flat_data[:,2], dtype='float')
 
-    if precision is None:
-        if typecode in 'fF':
-            precision = 8
+            I -= 1 #adjust indices (base 1 -> base 0)
+            J -= 1
+
+            if has_symmetry:
+                mask = (I != J)       #off diagonal mask
+                od_I = I[mask]
+                od_J = J[mask]
+                od_V = V[mask]
+
+                I = concatenate((I,od_J))
+                J = concatenate((J,od_I))
+
+                if is_skew:
+                    od_V *= -1
+                elif is_herm:
+                    od_V = od_V.conjugate()
+
+                V = concatenate((V,od_V))
+
+            a = coo_matrix((V, (I, J)), dims=(rows, cols), dtype=dtype)
         else:
-            precision = 16
-    if field is None:
-        if typecode in 'li':
-            field = 'integer'
-        elif typecode in 'df':
-            field = 'real'
-        elif typecode in 'DF':
-            field = 'complex'
-        else:
-            raise TypeError,'unexpected typecode '+typecode
+            raise NotImplementedError,`format`
 
-    if rep == 'array':
-        symm = _get_symmetry(a)
-    else:
-        symm = 'general'
+        return a
 
-    target.write('%%%%MatrixMarket matrix %s %s %s\n' % (rep,field,symm))
+    #---------------------------------------------------------------------------
+    def _write(self, stream, a, comment='', field=None, precision=None):
 
-    for line in comment.split('\n'):
-        target.write('%%%s\n' % (line))
+        if isinstance(a, list) or isinstance(a, ndarray) or isinstance(a, tuple) or hasattr(a,'__array__'):
+            rep = self.FORMAT_ARRAY
+            a = asarray(a)
+            if len(a.shape) != 2:
+                raise ValueError, 'expected matrix'
+            rows,cols = a.shape
+            entries = rows*cols
 
-    if field in ['real','integer']:
-        if field=='real':
-            format = '%%.%ie\n' % precision
+            if field is not None:
+
+                if field == self.FIELD_INTEGER:
+                    a = a.astype('i')
+                elif field == self.FIELD_REAL:
+                    if a.dtype.char not in 'fd':
+                        a = a.astype('d')
+                elif field == self.FIELD_COMPLEX:
+                    if a.dtype.char not in 'FD':
+                        a = a.astype('D')
+
         else:
-            format = '%i\n'
-    elif field=='complex':
-        format = '%%.%ie %%.%ie\n' % (precision,precision)
+            from scipy.sparse import spmatrix
+            if not isinstance(a,spmatrix):
+                raise ValueError,'unknown matrix type ' + `type(a)`
+            rep = 'coordinate'
+            rows, cols = a.shape
+            entries = a.getnnz()
 
-    if rep == 'array':
-        target.write('%i %i\n' % (rows,cols))
-        if field in ['real','integer']:
-            if symm=='general':
-                for j in range(cols):
-                    for i in range(rows):
-                        target.write(format % a[i,j])
+        typecode = a.dtype.char
+
+        if precision is None:
+            if typecode in 'fF':
+                precision = 8
             else:
-                for j in range(cols):
-                    for i in range(j,rows):
-                        target.write(format % a[i,j])
-        elif field=='complex':
-            if symm=='general':
-                for j in range(cols):
-                    for i in range(rows):
-                        aij = a[i,j]
-                        target.write(format % (real(aij),imag(aij)))
+                precision = 16
+
+        if field is None:
+            if typecode in 'li':
+                field = 'integer'
+            elif typecode in 'df':
+                field = 'real'
+            elif typecode in 'DF':
+                field = 'complex'
             else:
-                for j in range(cols):
-                    for i in range(j,rows):
-                        aij = a[i,j]
-                        target.write(format % (real(aij),imag(aij)))
-        elif field=='pattern':
-            raise ValueError,'Pattern type inconsisted with dense matrix'
+                raise TypeError,'unexpected typecode '+typecode
+
+        if rep == self.FORMAT_ARRAY:
+            symm = self._get_symmetry(a)
         else:
-            raise TypeError,'Unknown matrix type '+`field`
-    else:
-        format = '%i %i ' + format
-        target.write('%i %i %i\n' % (rows,cols,entries))
-        assert symm=='general',`symm`
+            symm = self.SYMMETRY_GENERAL
 
-        coo = a.tocoo() # convert to COOrdinate format
-        I,J,V = coo.row + 1, coo.col + 1, coo.data # change base 0 -> base 1
+        # validate rep, field, and symmetry
+        self.__class__._validate_format(rep)
+        self.__class__._validate_field(field)
+        self.__class__._validate_symmetry(symm)
 
-        if field in ['real','integer']:
-            for ijv_tuple in izip(I,J,V):
-                target.writelines(format % ijv_tuple)
-        elif field=='complex':
-            for ijv_tuple in izip(I,J,V.real,V.imag):
-                target.writelines(format % ijv_tuple)
-        elif field=='pattern':
-            raise NotImplementedError,`field`
+        # write initial header line
+        stream.write('%%%%MatrixMarket matrix %s %s %s\n' % (rep,field,symm))
+
+        # write comments
+        for line in comment.split('\n'):
+            stream.write('%%%s\n' % (line))
+
+
+        template = self._field_template(field, precision)
+
+        # write dense format
+        if rep == self.FORMAT_ARRAY:
+
+            # write shape spec
+            stream.write('%i %i\n' % (rows,cols))
+
+            if field in (self.FIELD_INTEGER, self.FIELD_REAL):
+
+                if symm == self.SYMMETRY_GENERAL:
+                    for j in range(cols):
+                        for i in range(rows):
+                            stream.write(template % a[i,j])
+                else:
+                    for j in range(cols):
+                        for i in range(j,rows):
+                            stream.write(template % a[i,j])
+
+            elif field == self.FIELD_COMPLEX:
+
+                if symm == self.SYMMETRY_GENERAL:
+                    for j in range(cols):
+                        for i in range(rows):
+                            aij = a[i,j]
+                            stream.write(template % (real(aij),imag(aij)))
+                else:
+                    for j in range(cols):
+                        for i in range(j,rows):
+                            aij = a[i,j]
+                            stream.write(template % (real(aij),imag(aij)))
+
+            elif field == self.FIELD_PATTERN:
+                raise ValueError,'pattern type inconsisted with dense format'
+
+            else:
+                raise TypeError,'Unknown field type %s'% `field`
+
+        # write sparse format
         else:
-            raise TypeError,'Unknown matrix type '+`field`
 
-    if close_it:
-        target.close()
-    else:
-        target.flush()
-    return
+            if symm != self.SYMMETRY_GENERAL:
+                raise ValueError, 'symmetric matrices incompatible with sparse format'
 
-def _get_symmetry(a):
-    m,n = a.shape
-    if m!=n:
-        return 'general'
-    issymm = 1
-    isskew = 1
-    isherm = a.dtype.char in 'FD'
-    for j in range(n):
-        for i in range(j+1,n):
-            aij,aji = a[i][j],a[j][i]
-            if issymm and aij != aji:
-                issymm = 0
-            if isskew and aij != -aji:
-                isskew = 0
-            if isherm and aij != conj(aji):
-                isherm = 0
-            if not (issymm or isskew or isherm):
-                break
-    if issymm: return 'symmetric'
-    if isskew: return 'skew-symmetric'
-    if isherm: return 'hermitian'
-    return 'general'
+            # write shape spec
+            stream.write('%i %i %i\n' % (rows,cols,entries))
 
+            # line template
+            template = '%i %i ' + template
+
+            coo = a.tocoo() # convert to COOrdinate format
+            I,J,V = coo.row + 1, coo.col + 1, coo.data # change base 0 -> base 1
+
+            if field in (self.FIELD_REAL, self.FIELD_INTEGER):
+                for ijv_tuple in izip(I,J,V):
+                    stream.writelines(template % ijv_tuple)
+            elif field == self.FIELD_COMPLEX:
+                for ijv_tuple in izip(I,J,V.real,V.imag):
+                    stream.writelines(template % ijv_tuple)
+            elif field == self.FIELD_PATTERN:
+                raise NotImplementedError,`field`
+            else:
+                raise TypeError,'Unknown field type %s'% `field`
+
+
+#-------------------------------------------------------------------------------
 if __name__ == '__main__':
     import sys
     import time




More information about the Scipy-svn mailing list