[Scipy-svn] r3675 - branches/io_new
scipy-svn at scipy.org
scipy-svn at scipy.org
Sat Dec 15 23:35:46 EST 2007
Author: brian.hawthorne
Date: 2007-12-15 22:35:42 -0600 (Sat, 15 Dec 2007)
New Revision: 3675
Modified:
branches/io_new/mmio.py
Log:
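cleanup, factoring: move the mminfo/mmread/mmwrite implementations into a new MMFile class (now exported in __all__); the module-level functions delegate to it.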
Modified: branches/io_new/mmio.py
===================================================================
--- branches/io_new/mmio.py 2007-12-16 04:35:23 UTC (rev 3674)
+++ branches/io_new/mmio.py 2007-12-16 04:35:42 UTC (rev 3675)
@@ -14,8 +14,10 @@
empty, concatenate, ones, ascontiguousarray
from itertools import izip
-__all__ = ['mminfo','mmread','mmwrite']
+__all__ = ['mminfo','mmread','mmwrite', 'MMFile']
+
+#-------------------------------------------------------------------------------
def mminfo(source):
""" Queries the contents of the Matrix Market file 'filename' to
extract size and storage information.
@@ -29,51 +31,13 @@
rows,cols - number of matrix rows and columns
entries - number of non-zero entries of a sparse matrix
or rows*cols for a dense matrix
- rep - 'coordinate' | 'array'
+ format - 'coordinate' | 'array'
field - 'real' | 'complex' | 'pattern' | 'integer'
symm - 'general' | 'symmetric' | 'skew-symmetric' | 'hermitian'
"""
- close_it = 0
- if type(source) is type(''):
- if not os.path.isfile(source):
- if source[-4:] != '.mtx':
- source = source + '.mtx'
- source = open(source,'r')
- close_it = 1
- line = source.readline().split()
- if not line[0].startswith('%%MatrixMarket'):
- raise ValueError,'source is not in Matrix Market format'
+ return MMFile.info(source)
- assert len(line)==5,`line`
-
- assert line[1].strip().lower()=='matrix',`line`
-
- rep = line[2].strip().lower()
- if rep=='dense': rep='array'
- elif rep=='sparse': rep='coordinate'
-
- field = line[3].strip().lower()
-
- symm = line[4].strip().lower()
-
- while line:
- line = source.readline()
- if line.startswith('%'):
- continue
- line = line.split()
- if rep=='array':
- assert len(line)==2,`line`
- rows,cols = map(eval,line)
- entries = rows*cols
- else:
- assert len(line)==3,`line`
- rows,cols,entries = map(eval,line)
- break
-
- if close_it:
- source.close()
- return (rows,cols,entries,rep,field,symm)
-
+#-------------------------------------------------------------------------------
def mmread(source):
""" Reads the contents of a Matrix Market file 'filename' into a matrix.
@@ -86,313 +50,525 @@
a - sparse or full matrix
"""
- close_it = 0
- if type(source) is type(''):
- if not os.path.isfile(source):
- if os.path.isfile(source+'.mtx'):
- source = source + '.mtx'
- elif os.path.isfile(source+'.mtx.gz'):
- source = source + '.mtx.gz'
- if source[-3:] == '.gz':
- import gzip
- source = gzip.open(source)
- else:
- source = open(source,'r')
- close_it = 1
+ return MMFile().read(source)
- rows,cols,entries,rep,field,symm = mminfo(source)
+#-------------------------------------------------------------------------------
+def mmwrite(target, a, comment='', field=None, precision=None):
+ """ Writes the sparse or dense matrix A to a Matrix Market formatted file.
- try:
- from scipy.sparse import coo_matrix
- except ImportError:
- coo_matrix = None
+ Inputs:
- if field=='integer':
- dtype='i'
- elif field=='real':
- dtype='d'
- elif field=='complex':
- dtype='D'
- elif field=='pattern':
- dtype='d'
- else:
- raise ValueError,`field`
+ target - Matrix Market filename (extension .mtx) or open file object
+ a - sparse or full matrix
+ comment - comments to be prepended to the Matrix Market file
+ field - 'real' | 'complex' | 'pattern' | 'integer'
+ precision - Number of digits to display for real or complex values.
+ """
+ MMFile().write(target, a, comment, field, precision)
- has_symmetry = symm in ['symmetric','skew-symmetric','hermitian']
- is_complex = field=='complex'
- is_skew = symm=='skew-symmetric'
- is_herm = symm=='hermitian'
- is_pattern = field=='pattern'
- if rep == 'array':
- a = zeros((rows,cols),dtype=dtype)
- line = 1
- i,j = 0,0
- while line:
+################################################################################
+class MMFile (object):
+ __slots__ = (
+ '_rows',
+ '_cols',
+ '_entries',
+ '_format',
+ '_field',
+ '_symmetry')
+
+ @property
+ def rows(self): return self._rows
+ @property
+ def cols(self): return self._cols
+ @property
+ def entries(self): return self._entries
+ @property
+ def format(self): return self._format
+ @property
+ def field(self): return self._field
+ @property
+ def symmetry(self): return self._symmetry
+
+ @property
+ def has_symmetry(self):
+ return self._symmetry in (self.SYMMETRY_SYMMETRIC,
+ self.SYMMETRY_SKEW_SYMMETRIC, self.SYMMETRY_HERMITIAN)
+
+ # format values
+ FORMAT_COORDINATE = 'coordinate'
+ FORMAT_ARRAY = 'array'
+ FORMAT_VALUES = (FORMAT_COORDINATE, FORMAT_ARRAY)
+
+ @classmethod
+ def _validate_format(self, format):
+ if format not in self.FORMAT_VALUES:
+ raise ValueError,'unknown format type %s, must be one of %s'% \
+ (`format`, `self.FORMAT_VALUES`)
+
+ # field values
+ FIELD_INTEGER = 'integer'
+ FIELD_REAL = 'real'
+ FIELD_COMPLEX = 'complex'
+ FIELD_PATTERN = 'pattern'
+ FIELD_VALUES = (FIELD_INTEGER, FIELD_REAL, FIELD_COMPLEX, FIELD_PATTERN)
+
+ @classmethod
+ def _validate_field(self, field):
+ if field not in self.FIELD_VALUES:
+ raise ValueError,'unknown field type %s, must be one of %s'% \
+ (`field`, `self.FIELD_VALUES`)
+
+ # symmetry values
+ SYMMETRY_GENERAL = 'general'
+ SYMMETRY_SYMMETRIC = 'symmetric'
+ SYMMETRY_SKEW_SYMMETRIC = 'skew-symmetric'
+ SYMMETRY_HERMITIAN = 'hermitian'
+ SYMMETRY_VALUES = (
+ SYMMETRY_GENERAL, SYMMETRY_SYMMETRIC, SYMMETRY_SKEW_SYMMETRIC,
+ SYMMETRY_HERMITIAN)
+
+ @classmethod
+ def _validate_symmetry(self, symmetry):
+ if symmetry not in self.SYMMETRY_VALUES:
+ raise ValueError,'unknown symmetry type %s, must be one of %s'% \
+ (`symmetry`, `self.SYMMETRY_VALUES`)
+
+ DTYPES_BY_FIELD = {
+ FIELD_INTEGER: 'i',
+ FIELD_REAL: 'd',
+ FIELD_COMPLEX:'D',
+ FIELD_PATTERN:'d'}
+
+ #---------------------------------------------------------------------------
+ @staticmethod
+ def reader(): pass
+
+ #---------------------------------------------------------------------------
+ @staticmethod
+ def writer(): pass
+
+ #---------------------------------------------------------------------------
+ @classmethod
+ def info(self, source):
+ source, close_it = self._open(source)
+
+ try:
+
+ # read and validate header line
line = source.readline()
- if not line or line.startswith('%'):
- continue
- if is_complex:
- aij = complex(*map(float,line.split()))
+ mmid, matrix, format, field, symmetry = \
+ [part.strip().lower() for part in line.split()]
+ if not mmid.startswith('%%matrixmarket'):
+ raise ValueError,'source is not in Matrix Market format'
+
+ assert matrix == 'matrix',`line`
+
+ # ??? Is this necessary? I don't see 'dense' or 'sparse' in the spec
+ # http://math.nist.gov/MatrixMarket/formats.html
+ if format == 'dense': format = self.FORMAT_ARRAY
+ elif format == 'sparse': format = self.FORMAT_COORDINATE
+
+ # skip comments
+ while line.startswith('%'): line = source.readline()
+
+ line = line.split()
+ if format == self.FORMAT_ARRAY:
+ assert len(line)==2,`line`
+ rows,cols = map(float, line)
+ entries = rows*cols
else:
- aij = float(line)
- a[i,j] = aij
- if has_symmetry and i!=j:
- if is_skew:
- a[j,i] = -aij
- elif is_herm:
- a[j,i] = conj(aij)
+ assert len(line)==3,`line`
+ rows, cols, entries = map(float, line)
+
+ return (rows, cols, entries, format, field, symmetry)
+
+ finally:
+ if close_it: source.close()
+
+ #---------------------------------------------------------------------------
+ @staticmethod
+ def _open(filespec, mode='r'):
+ """
+ Return an open file stream for reading based on source. If source is
+ a file name, open it (after trying to find it with mtx and gzipped mtx
+ extensions). Otherwise, just return source.
+ """
+ close_it = False
+ if type(filespec) is type(''):
+ close_it = True
+
+ # open for reading
+ if mode[0] == 'r':
+
+ # determine filename plus extension
+ if not os.path.isfile(filespec):
+ if os.path.isfile(filespec+'.mtx'):
+ filespec = filespec + '.mtx'
+ elif os.path.isfile(filespec+'.mtx.gz'):
+ filespec = filespec + '.mtx.gz'
+ # open filename
+ if filespec[-3:] == '.gz':
+ import gzip
+ stream = gzip.open(filespec, mode)
else:
- a[j,i] = aij
- if i<rows-1:
- i = i + 1
+ stream = open(filespec, mode)
+
+ # open for writing
else:
- j = j + 1
- if not has_symmetry:
- i = 0
- else:
- i = j
- assert i in [0,j] and j==cols,`i,j,rows,cols`
+ if filespec[-4:] != '.mtx':
+ filespec = filespec + '.mtx'
+ stream = open(filespec, mode)
+ else:
+ stream = filespec
- elif rep=='coordinate' and coo_matrix is None:
- # Read sparse matrix to dense when coo_matrix is not available.
- a = zeros((rows,cols), dtype=dtype)
- line = 1
- k = 0
- while line:
- line = source.readline()
- if not line or line.startswith('%'):
- continue
- l = line.split()
- i,j = map(int,l[:2])
- i,j = i-1,j-1
- if is_complex:
- aij = complex(*map(float,l[2:]))
- else:
- aij = float(l[2])
- a[i,j] = aij
- if has_symmetry and i!=j:
- if is_skew:
- a[j,i] = -aij
- elif is_herm:
- a[j,i] = conj(aij)
- else:
- a[j,i] = aij
- k = k + 1
- assert k==entries,`k,entries`
+ return stream, close_it
- elif rep=='coordinate':
- from numpy import fromfile,fromstring
+ #---------------------------------------------------------------------------
+ @staticmethod
+ def _get_symmetry(a):
+ m,n = a.shape
+ if m!=n:
+ return MMFile.SYMMETRY_GENERAL
+ issymm = 1
+ isskew = 1
+ isherm = a.dtype.char in 'FD'
+ for j in range(n):
+ for i in range(j+1,n):
+ aij,aji = a[i][j],a[j][i]
+ if issymm and aij != aji:
+ issymm = 0
+ if isskew and aij != -aji:
+ isskew = 0
+ if isherm and aij != conj(aji):
+ isherm = 0
+ if not (issymm or isskew or isherm):
+ break
+ if issymm: return MMFile.SYMMETRY_SYMMETRIC
+ if isskew: return MMFile.SYMMETRY_SKEW_SYMMETRIC
+ if isherm: return MMFile.SYMMETRY_HERMITIAN
+ return MMFile.SYMMETRY_GENERAL
+
+ #---------------------------------------------------------------------------
+ @staticmethod
+ def _field_template(field, precision):
+ return {
+ MMFile.FIELD_REAL: '%%.%ie\n' % precision,
+ MMFile.FIELD_INTEGER: '%i\n',
+ MMFile.FIELD_COMPLEX: '%%.%ie %%.%ie\n' % (precision,precision)
+ }.get(field, None)
+
+ #---------------------------------------------------------------------------
+ def __init__(self, **kwargs): self._init_attrs(**kwargs)
+
+ #---------------------------------------------------------------------------
+ def read(self, source):
+ stream, close_it = self._open(source)
+
try:
- # fromfile works for normal files
- flat_data = fromfile(source,sep=' ')
- except:
- # fallback - fromfile fails for some file-like objects
- flat_data = fromstring(source.read(),sep=' ')
-
- # TODO use iterator (e.g. xreadlines) to avoid reading
- # the whole file into memory
-
- if is_pattern:
- flat_data = flat_data.reshape(-1,2)
- I = ascontiguousarray(flat_data[:,0], dtype='intc')
- J = ascontiguousarray(flat_data[:,1], dtype='intc')
- V = ones(len(I))
- elif is_complex:
- flat_data = flat_data.reshape(-1,4)
- I = ascontiguousarray(flat_data[:,0], dtype='intc')
- J = ascontiguousarray(flat_data[:,1], dtype='intc')
- V = ascontiguousarray(flat_data[:,2], dtype='complex')
- V.imag = flat_data[:,3]
- else:
- flat_data = flat_data.reshape(-1,3)
- I = ascontiguousarray(flat_data[:,0], dtype='intc')
- J = ascontiguousarray(flat_data[:,1], dtype='intc')
- V = ascontiguousarray(flat_data[:,2], dtype='float')
+ self._parse_header(stream)
+ return self._parse_body(stream)
- I -= 1 #adjust indices (base 1 -> base 0)
- J -= 1
+ finally:
+ if close_it: stream.close()
- if has_symmetry:
- mask = (I != J) #off diagonal mask
- od_I = I[mask]
- od_J = J[mask]
- od_V = V[mask]
+ #---------------------------------------------------------------------------
+ def write(self, target, a, comment='', field=None, precision=None):
+ stream, close_it = self._open(target, 'w')
- I = concatenate((I,od_J))
- J = concatenate((J,od_I))
+ try:
+ self._write(stream, a, comment, field, precision)
- if is_skew:
- od_V *= -1
- elif is_herm:
- od_V = od_V.conjugate()
+ finally:
+ if close_it: stream.close()
+ else: stream.flush()
- V = concatenate((V,od_V))
- a = coo_matrix((V, (I, J)), dims=(rows, cols), dtype=dtype)
- else:
- raise NotImplementedError,`rep`
+ #---------------------------------------------------------------------------
+ def _init_attrs(self, **kwargs):
+ """
+ Initialize each attributes with the corresponding keyword arg value
+ or a default of None
+ """
+ attrs = self.__class__.__slots__
+ public_attrs = [attr[1:] for attr in attrs]
+ invalid_keys = set(kwargs.keys()) - set(public_attrs)
+
+ if invalid_keys:
+ raise ValueError, \
+ 'found %s invalid keyword arguments, please only use %s' % \
+ (`tuple(invalid_keys)`, `public_attrs`)
- if close_it:
- source.close()
- return a
+ for attr in attrs: setattr(self, attr, kwargs.get(attr[1:], None))
-def mmwrite(target,a,comment='',field=None,precision=None):
- """ Writes the sparse or dense matrix A to a Matrix Market formatted file.
+ #---------------------------------------------------------------------------
+ def _parse_header(self, stream):
+ rows, cols, entries, format, field, symmetry = \
+ self.__class__.info(stream)
+ self._init_attrs(rows=rows, cols=cols, entries=entries, format=format,
+ field=field, symmetry=symmetry)
- Inputs:
+ #---------------------------------------------------------------------------
+ def _parse_body(self, stream):
+ rows, cols, entries, format, field, symm = \
+ (self.rows, self.cols, self.entries, self.format, self.field, self.symmetry)
- target - Matrix Market filename (extension .mtx) or open file object
- a - sparse or full matrix
- comment - comments to be prepended to the Matrix Market file
- field - 'real' | 'complex' | 'pattern' | 'integer'
- precision - Number of digits to display for real or complex values.
- """
- close_it = 0
- if type(target) is type(''):
- if target[-4:] != '.mtx':
- target = target + '.mtx'
- target = open(target,'w')
- close_it = 1
+ try:
+ from scipy.sparse import coo_matrix
+ except ImportError:
+ coo_matrix = None
- if isinstance(a, list) or isinstance(a, ndarray) or isinstance(a, tuple) or hasattr(a,'__array__'):
- rep = 'array'
- a = asarray(a)
- if len(a.shape) != 2:
- raise ValueError, 'expected matrix'
- rows,cols = a.shape
- entries = rows*cols
- typecode = a.dtype.char
- if field is not None:
- if field=='integer':
- a = a.astype('i')
- elif field=='real':
- if typecode not in 'fd':
- a = a.astype('d')
- elif field=='complex':
- if typecode not in 'FD':
- a = a.astype('D')
- elif field=='pattern':
- pass
+ dtype = self.DTYPES_BY_FIELD.get(field, None)
+
+ has_symmetry = self.has_symmetry
+ is_complex = field == self.FIELD_COMPLEX
+ is_skew = symm == self.SYMMETRY_SKEW_SYMMETRIC
+ is_herm = symm == self.SYMMETRY_HERMITIAN
+ is_pattern = field == self.FIELD_PATTERN
+
+ if format == self.FORMAT_ARRAY:
+ a = zeros((rows,cols),dtype=dtype)
+ line = 1
+ i,j = 0,0
+ while line:
+ line = stream.readline()
+ if not line or line.startswith('%'):
+ continue
+ if is_complex:
+ aij = complex(*map(float,line.split()))
+ else:
+ aij = float(line)
+ a[i,j] = aij
+ if has_symmetry and i!=j:
+ if is_skew:
+ a[j,i] = -aij
+ elif is_herm:
+ a[j,i] = conj(aij)
+ else:
+ a[j,i] = aij
+ if i<rows-1:
+ i = i + 1
+ else:
+ j = j + 1
+ if not has_symmetry:
+ i = 0
+ else:
+ i = j
+ assert i in [0,j] and j==cols,`i,j,rows,cols`
+
+ elif format == self.FORMAT_COORDINATE and coo_matrix is None:
+ # Read sparse matrix to dense when coo_matrix is not available.
+ a = zeros((rows,cols), dtype=dtype)
+ line = 1
+ k = 0
+ while line:
+ line = stream.readline()
+ if not line or line.startswith('%'):
+ continue
+ l = line.split()
+ i,j = map(int,l[:2])
+ i,j = i-1,j-1
+ if is_complex:
+ aij = complex(*map(float,l[2:]))
+ else:
+ aij = float(l[2])
+ a[i,j] = aij
+ if has_symmetry and i!=j:
+ if is_skew:
+ a[j,i] = -aij
+ elif is_herm:
+ a[j,i] = conj(aij)
+ else:
+ a[j,i] = aij
+ k = k + 1
+ assert k==entries,`k,entries`
+
+ elif format == self.FORMAT_COORDINATE:
+ from numpy import fromfile,fromstring
+ try:
+ # fromfile works for normal files
+ flat_data = fromfile(stream,sep=' ')
+ except:
+ # fallback - fromfile fails for some file-like objects
+ flat_data = fromstring(stream.read(),sep=' ')
+
+ # TODO use iterator (e.g. xreadlines) to avoid reading
+ # the whole file into memory
+
+ if is_pattern:
+ flat_data = flat_data.reshape(-1,2)
+ I = ascontiguousarray(flat_data[:,0], dtype='intc')
+ J = ascontiguousarray(flat_data[:,1], dtype='intc')
+ V = ones(len(I))
+ elif is_complex:
+ flat_data = flat_data.reshape(-1,4)
+ I = ascontiguousarray(flat_data[:,0], dtype='intc')
+ J = ascontiguousarray(flat_data[:,1], dtype='intc')
+ V = ascontiguousarray(flat_data[:,2], dtype='complex')
+ V.imag = flat_data[:,3]
else:
- raise ValueError,'unknown field '+field
- typecode = a.dtype.char
- else:
- rep = 'coordinate'
- from scipy.sparse import spmatrix
- if not isinstance(a,spmatrix):
- raise ValueError,'unknown matrix type ' + `type(a)`
- rows,cols = a.shape
- entries = a.getnnz()
- typecode = a.dtype.char
+ flat_data = flat_data.reshape(-1,3)
+ I = ascontiguousarray(flat_data[:,0], dtype='intc')
+ J = ascontiguousarray(flat_data[:,1], dtype='intc')
+ V = ascontiguousarray(flat_data[:,2], dtype='float')
- if precision is None:
- if typecode in 'fF':
- precision = 8
+ I -= 1 #adjust indices (base 1 -> base 0)
+ J -= 1
+
+ if has_symmetry:
+ mask = (I != J) #off diagonal mask
+ od_I = I[mask]
+ od_J = J[mask]
+ od_V = V[mask]
+
+ I = concatenate((I,od_J))
+ J = concatenate((J,od_I))
+
+ if is_skew:
+ od_V *= -1
+ elif is_herm:
+ od_V = od_V.conjugate()
+
+ V = concatenate((V,od_V))
+
+ a = coo_matrix((V, (I, J)), dims=(rows, cols), dtype=dtype)
else:
- precision = 16
- if field is None:
- if typecode in 'li':
- field = 'integer'
- elif typecode in 'df':
- field = 'real'
- elif typecode in 'DF':
- field = 'complex'
- else:
- raise TypeError,'unexpected typecode '+typecode
+ raise NotImplementedError,`format`
- if rep == 'array':
- symm = _get_symmetry(a)
- else:
- symm = 'general'
+ return a
- target.write('%%%%MatrixMarket matrix %s %s %s\n' % (rep,field,symm))
+ #---------------------------------------------------------------------------
+ def _write(self, stream, a, comment='', field=None, precision=None):
- for line in comment.split('\n'):
- target.write('%%%s\n' % (line))
+ if isinstance(a, list) or isinstance(a, ndarray) or isinstance(a, tuple) or hasattr(a,'__array__'):
+ rep = self.FORMAT_ARRAY
+ a = asarray(a)
+ if len(a.shape) != 2:
+ raise ValueError, 'expected matrix'
+ rows,cols = a.shape
+ entries = rows*cols
- if field in ['real','integer']:
- if field=='real':
- format = '%%.%ie\n' % precision
+ if field is not None:
+
+ if field == self.FIELD_INTEGER:
+ a = a.astype('i')
+ elif field == self.FIELD_REAL:
+ if a.dtype.char not in 'fd':
+ a = a.astype('d')
+ elif field == self.FIELD_COMPLEX:
+ if a.dtype.char not in 'FD':
+ a = a.astype('D')
+
else:
- format = '%i\n'
- elif field=='complex':
- format = '%%.%ie %%.%ie\n' % (precision,precision)
+ from scipy.sparse import spmatrix
+ if not isinstance(a,spmatrix):
+ raise ValueError,'unknown matrix type ' + `type(a)`
+ rep = 'coordinate'
+ rows, cols = a.shape
+ entries = a.getnnz()
- if rep == 'array':
- target.write('%i %i\n' % (rows,cols))
- if field in ['real','integer']:
- if symm=='general':
- for j in range(cols):
- for i in range(rows):
- target.write(format % a[i,j])
+ typecode = a.dtype.char
+
+ if precision is None:
+ if typecode in 'fF':
+ precision = 8
else:
- for j in range(cols):
- for i in range(j,rows):
- target.write(format % a[i,j])
- elif field=='complex':
- if symm=='general':
- for j in range(cols):
- for i in range(rows):
- aij = a[i,j]
- target.write(format % (real(aij),imag(aij)))
+ precision = 16
+
+ if field is None:
+ if typecode in 'li':
+ field = 'integer'
+ elif typecode in 'df':
+ field = 'real'
+ elif typecode in 'DF':
+ field = 'complex'
else:
- for j in range(cols):
- for i in range(j,rows):
- aij = a[i,j]
- target.write(format % (real(aij),imag(aij)))
- elif field=='pattern':
- raise ValueError,'Pattern type inconsisted with dense matrix'
+ raise TypeError,'unexpected typecode '+typecode
+
+ if rep == self.FORMAT_ARRAY:
+ symm = self._get_symmetry(a)
else:
- raise TypeError,'Unknown matrix type '+`field`
- else:
- format = '%i %i ' + format
- target.write('%i %i %i\n' % (rows,cols,entries))
- assert symm=='general',`symm`
+ symm = self.SYMMETRY_GENERAL
- coo = a.tocoo() # convert to COOrdinate format
- I,J,V = coo.row + 1, coo.col + 1, coo.data # change base 0 -> base 1
+ # validate rep, field, and symmetry
+ self.__class__._validate_format(rep)
+ self.__class__._validate_field(field)
+ self.__class__._validate_symmetry(symm)
- if field in ['real','integer']:
- for ijv_tuple in izip(I,J,V):
- target.writelines(format % ijv_tuple)
- elif field=='complex':
- for ijv_tuple in izip(I,J,V.real,V.imag):
- target.writelines(format % ijv_tuple)
- elif field=='pattern':
- raise NotImplementedError,`field`
+ # write initial header line
+ stream.write('%%%%MatrixMarket matrix %s %s %s\n' % (rep,field,symm))
+
+ # write comments
+ for line in comment.split('\n'):
+ stream.write('%%%s\n' % (line))
+
+
+ template = self._field_template(field, precision)
+
+ # write dense format
+ if rep == self.FORMAT_ARRAY:
+
+ # write shape spec
+ stream.write('%i %i\n' % (rows,cols))
+
+ if field in (self.FIELD_INTEGER, self.FIELD_REAL):
+
+ if symm == self.SYMMETRY_GENERAL:
+ for j in range(cols):
+ for i in range(rows):
+ stream.write(template % a[i,j])
+ else:
+ for j in range(cols):
+ for i in range(j,rows):
+ stream.write(template % a[i,j])
+
+ elif field == self.FIELD_COMPLEX:
+
+ if symm == self.SYMMETRY_GENERAL:
+ for j in range(cols):
+ for i in range(rows):
+ aij = a[i,j]
+ stream.write(template % (real(aij),imag(aij)))
+ else:
+ for j in range(cols):
+ for i in range(j,rows):
+ aij = a[i,j]
+ stream.write(template % (real(aij),imag(aij)))
+
+ elif field == self.FIELD_PATTERN:
+ raise ValueError,'pattern type inconsisted with dense format'
+
+ else:
+ raise TypeError,'Unknown field type %s'% `field`
+
+ # write sparse format
else:
- raise TypeError,'Unknown matrix type '+`field`
- if close_it:
- target.close()
- else:
- target.flush()
- return
+ if symm != self.SYMMETRY_GENERAL:
+ raise ValueError, 'symmetric matrices incompatible with sparse format'
-def _get_symmetry(a):
- m,n = a.shape
- if m!=n:
- return 'general'
- issymm = 1
- isskew = 1
- isherm = a.dtype.char in 'FD'
- for j in range(n):
- for i in range(j+1,n):
- aij,aji = a[i][j],a[j][i]
- if issymm and aij != aji:
- issymm = 0
- if isskew and aij != -aji:
- isskew = 0
- if isherm and aij != conj(aji):
- isherm = 0
- if not (issymm or isskew or isherm):
- break
- if issymm: return 'symmetric'
- if isskew: return 'skew-symmetric'
- if isherm: return 'hermitian'
- return 'general'
+ # write shape spec
+ stream.write('%i %i %i\n' % (rows,cols,entries))
+ # line template
+ template = '%i %i ' + template
+
+ coo = a.tocoo() # convert to COOrdinate format
+ I,J,V = coo.row + 1, coo.col + 1, coo.data # change base 0 -> base 1
+
+ if field in (self.FIELD_REAL, self.FIELD_INTEGER):
+ for ijv_tuple in izip(I,J,V):
+ stream.writelines(template % ijv_tuple)
+ elif field == self.FIELD_COMPLEX:
+ for ijv_tuple in izip(I,J,V.real,V.imag):
+ stream.writelines(template % ijv_tuple)
+ elif field == self.FIELD_PATTERN:
+ raise NotImplementedError,`field`
+ else:
+ raise TypeError,'Unknown field type %s'% `field`
+
+
+#-------------------------------------------------------------------------------
if __name__ == '__main__':
import sys
import time
More information about the Scipy-svn mailing list