[Scipy-svn] r6356 - in trunk/scipy/sparse/linalg/dsolve: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Tue Apr 27 17:57:54 EDT 2010
Author: ptvirtan
Date: 2010-04-27 16:57:54 -0500 (Tue, 27 Apr 2010)
New Revision: 6356
Modified:
trunk/scipy/sparse/linalg/dsolve/_superluobject.c
trunk/scipy/sparse/linalg/dsolve/tests/test_linsolve.py
Log:
ENH: sparse.linalg.dsolve: expose perm_* attributes of the splu object to Python side (patch from #937)
Modified: trunk/scipy/sparse/linalg/dsolve/_superluobject.c
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/_superluobject.c 2010-04-27 21:57:39 UTC (rev 6355)
+++ trunk/scipy/sparse/linalg/dsolve/_superluobject.c 2010-04-27 21:57:54 UTC (rev 6356)
@@ -136,8 +136,22 @@
return Py_BuildValue("(i,i)", self->m, self->n);
if (strcmp(name, "nnz") == 0)
return Py_BuildValue("i", ((SCformat *)self->L.Store)->nnz + ((SCformat *)self->U.Store)->nnz);
+    if (strcmp(name, "perm_r") == 0 || strcmp(name, "perm_c") == 0) {
+        /* Expose the permutation vectors computed by the factorization as
+         * 1-d NPY_INT arrays that share memory with this object (no copy). */
+        npy_intp dims[1];   /* PyArray_SimpleNewFromData needs npy_intp
+                             * dims; self->m/n are plain ints, so taking
+                             * their address and casting to npy_intp*
+                             * would read past the field on LP64. */
+        void *data;
+        PyObject *perm;
+        if (strcmp(name, "perm_r") == 0) {
+            /* row permutation: one entry per row of the factored matrix */
+            dims[0] = self->m;
+            data = (void *)self->perm_r;
+        }
+        else {
+            /* column permutation: one entry per column */
+            dims[0] = self->n;
+            data = (void *)self->perm_c;
+        }
+        perm = PyArray_SimpleNewFromData(1, dims, NPY_INT, data);
+        if (perm == NULL)
+            return NULL;
+        /* The array does not own its buffer: make self the array's base so
+         * the permutation memory stays alive as long as the array does.
+         * The reference taken here is released when the array is freed. */
+        Py_INCREF(self);
+        PyArray_BASE((PyArrayObject *)perm) = (PyObject *)self;
+        return perm;
+    }
if (strcmp(name, "__members__") == 0) {
- char *members[] = {"shape", "nnz"};
+ char *members[] = {"shape", "nnz", "perm_r", "perm_c"};
int i;
PyObject *list = PyList_New(sizeof(members)/sizeof(char *));
@@ -158,6 +172,27 @@
/***********************************************************************
* SciPySuperLUType structure
*/
+/* Docstring of SciPySuperLUType (installed as its tp_doc slot).  Note that
+ * any whitespace before a line-continuation backslash becomes part of the
+ * string, so each line ends directly in "\n\". */
+static char factored_lu_doc[] = "\
+Object resulting from a factorization of a sparse matrix\n\
+\n\
+Attributes\n\
+----------\n\
+\n\
+shape : 2-tuple\n\
+    the shape of the original matrix factored\n\
+nnz : int\n\
+    the number of nonzero entries in the L and U factors combined\n\
+perm_c : ndarray of int\n\
+    the permutation applied to the columns of the matrix for the LU factorization\n\
+perm_r : ndarray of int\n\
+    the permutation applied to the rows of the matrix for the LU factorization\n\
+\n\
+Methods\n\
+-------\n\
+solve\n\
+    solves the system for a given right hand side vector\n\
+\n\
+";
PyTypeObject SciPySuperLUType = {
PyObject_HEAD_INIT(NULL)
@@ -175,6 +210,13 @@
0, /* tp_as_sequence*/
0, /* tp_as_mapping*/
0, /* tp_hash */
+ 0, /* tp_call */
+ 0, /* tp_str */
+ 0, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ 0, /* tp_flags */
+ factored_lu_doc, /* tp_doc */
};
Modified: trunk/scipy/sparse/linalg/dsolve/tests/test_linsolve.py
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/tests/test_linsolve.py 2010-04-27 21:57:39 UTC (rev 6355)
+++ trunk/scipy/sparse/linalg/dsolve/tests/test_linsolve.py 2010-04-27 21:57:54 UTC (rev 6356)
@@ -1,11 +1,11 @@
import warnings
-from numpy import array, finfo, arange
+from numpy import array, finfo, arange, eye, all, unique, ones, dot
import numpy.random as random
from numpy.testing import *
from scipy.linalg import norm, inv
-from scipy.sparse import spdiags, SparseEfficiencyWarning
+from scipy.sparse import spdiags, SparseEfficiencyWarning, csc_matrix
from scipy.sparse.linalg.dsolve import spsolve, use_solver, splu, spilu
warnings.simplefilter('ignore',SparseEfficiencyWarning)
@@ -14,11 +14,10 @@
use_solver( useUmfpack = False )
class TestLinsolve(TestCase):
- ## this crashes SuperLU
- #def test_singular(self):
- # A = csc_matrix( (5,5), dtype='d' )
- # b = array([1, 2, 3, 4, 5],dtype='d')
- # x = spsolve(A,b)
+    def test_singular(self):
+        # An all-zero (hence exactly singular) matrix used to crash SuperLU
+        # inside spsolve.  This is a regression smoke test: it passes as long
+        # as spsolve returns without crashing or raising; the solution
+        # itself is not checked.
+        A = csc_matrix( (5,5), dtype='d' )
+        b = array([1, 2, 3, 4, 5],dtype='d')
+        spsolve(A, b, use_umfpack=False)
def test_twodiags(self):
A = spdiags([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]], [0, 1], 5, 5)
@@ -48,18 +47,89 @@
self.A = spdiags((d, 2*d, d[::-1]), (-3, 0, 5), n, n)
random.seed(1234)
- def test_splu(self):
+ def test_splu_smoketest(self):
+ # Check that splu works at all
x = random.rand(self.n)
lu = splu(self.A)
r = self.A*lu.solve(x)
assert abs(x - r).max() < 1e-13
- def test_spilu(self):
+ def test_spilu_smoketest(self):
+ # Check that spilu works at all
x = random.rand(self.n)
lu = spilu(self.A, drop_tol=1e-2, fill_factor=5)
r = self.A*lu.solve(x)
assert abs(x - r).max() < 1e-2
assert abs(x - r).max() > 1e-5
+    def test_splu_nnz0(self):
+        # A 5x5 sparse matrix with no stored entries must be rejected by splu.
+        assert_raises(RuntimeError, splu, csc_matrix((5, 5), dtype='d'))
+
+    def test_spilu_nnz0(self):
+        # The same empty matrix must be rejected by the incomplete variant.
+        assert_raises(RuntimeError, spilu, csc_matrix((5, 5), dtype='d'))
+
+ def test_splu_basic(self):
+ # Test basic splu functionality.
+ # Checks that splu raises on an exactly singular matrix, and that the
+ # factorization of a nonsingular matrix solves a x = b accurately.
+ n = 30
+ a = random.random((n, n))
+ # Zero all entries below 0.95 (roughly 95% of them) to make a sparse.
+ a[a < 0.95] = 0
+ # First test with a singular matrix
+ # (column 0 is set entirely to zero)
+ a[:, 0] = 0
+ a_ = csc_matrix(a)
+ # Matrix is exactly singular
+ assert_raises(RuntimeError, splu, a_)
+
+ # Make a diagonal dominant, to make sure it is not singular
+ # (off-diagonal entries are sparse and < 1; this also restores a
+ # nonzero entry in column 0).
+ a += 4*eye(n)
+ a_ = csc_matrix(a)
+ lu = splu(a_)
+ b = ones(n)
+ # Solve with the factorization and check the result against the
+ # dense product.
+ x = lu.solve(b)
+ assert_almost_equal(dot(a, x), b)
+
+    def test_splu_perm(self):
+        # Test the permutation vectors exposed by splu.
+        n = 30
+        a = random.random((n, n))
+        a[a < 0.95] = 0
+        # Make a diagonal dominant, to make sure it is not singular
+        a += 4*eye(n)
+        a_ = csc_matrix(a)
+        lu = splu(a_)
+        # Check that the permutation indices do belong to [0, n-1]
+        # and that each index appears exactly once.
+        for perm in (lu.perm_r, lu.perm_c):
+            assert_(all(perm > -1))
+            assert_(all(perm < n))
+            assert_equal(len(unique(perm)), len(perm))
+
+        # Now make a symmetric, and test that the two permutation vectors are
+        # the same.
+        # Note: ``a += a.T`` would read from a transposed view of the very
+        # array being written (overlapping memory); older NumPy versions
+        # then produce a non-symmetric result.  Build a new array instead.
+        a = a + a.T
+        a_ = csc_matrix(a)
+        lu = splu(a_)
+        assert_array_equal(lu.perm_r, lu.perm_c)
+
+    def test_lu_refcount(self):
+        # Test that we are keeping track of the reference count with splu.
+        n = 30
+        a = random.random((n, n))
+        a[a < 0.95] = 0
+        # Make a diagonal dominant, to make sure it is not singular
+        a += 4*eye(n)
+        a_ = csc_matrix(a)
+        lu = splu(a_)
+
+        # And now test that we don't have a refcount bug: each perm_* array
+        # must hold exactly one extra reference to lu (its base object) and
+        # release it when the array is deleted.
+        import sys
+        rc = sys.getrefcount(lu)
+        for attr in ('perm_r', 'perm_c'):
+            perm = getattr(lu, attr)
+            assert_equal(sys.getrefcount(lu), rc + 1)
+            del perm
+            assert_equal(sys.getrefcount(lu), rc)
+
+
if __name__ == "__main__":
run_module_suite()
More information about the Scipy-svn
mailing list