[Scipy-svn] r6350 - trunk/scipy/sparse/linalg/dsolve
scipy-svn at scipy.org
scipy-svn at scipy.org
Tue Apr 27 17:56:20 EDT 2010
Author: ptvirtan
Date: 2010-04-27 16:56:19 -0500 (Tue, 27 Apr 2010)
New Revision: 6350
Modified:
trunk/scipy/sparse/linalg/dsolve/_superlumodule.c
trunk/scipy/sparse/linalg/dsolve/_superluobject.c
trunk/scipy/sparse/linalg/dsolve/_superluobject.h
trunk/scipy/sparse/linalg/dsolve/linsolve.py
Log:
ENH: sparse.linalg.dsolve: allow passing any SuperLU options to the factorization in the internal gstrf interface. Enable support for incomplete LU.
Modified: trunk/scipy/sparse/linalg/dsolve/_superlumodule.c
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/_superlumodule.c 2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/_superlumodule.c 2010-04-27 21:56:19 UTC (rev 6350)
@@ -33,21 +33,22 @@
PyArrayObject *colind=NULL, *rowptr=NULL;
int N, nnz;
int info;
- int csc=0, permc_spec=2;
+ int csc=0;
int *perm_r=NULL, *perm_c=NULL;
SuperMatrix A, B, L, U;
superlu_options_t options;
SuperLUStat_t stat;
+ PyObject *option_dict = NULL;
int type;
static char *kwlist[] = {"N","nnz","nzvals","colind","rowptr","B", "csc",
- "permc_spec",NULL};
+ "options",NULL};
/* Get input arguments */
- if (!PyArg_ParseTupleAndKeywords(args, kwdict, "iiO!O!O!O|ii", kwlist,
+ if (!PyArg_ParseTupleAndKeywords(args, kwdict, "iiO!O!O!O|iO", kwlist,
&N, &nnz, &PyArray_Type, &nzvals,
&PyArray_Type, &colind, &PyArray_Type,
- &rowptr, &Py_B, &csc, &permc_spec)) {
+ &rowptr, &Py_B, &csc, &option_dict)) {
return NULL;
}
@@ -64,6 +65,10 @@
return NULL;
}
+ if (!set_superlu_options_from_dict(&options, 0, option_dict)) {
+ return NULL;
+ }
+
/* Create Space for output */
Py_X = PyArray_CopyFromObject(Py_B, type, 1, 2);
if (Py_X == NULL) return NULL;
@@ -99,8 +104,6 @@
else {
perm_c = intMalloc(N);
perm_r = intMalloc(N);
- set_default_options(&options);
- options.ColPerm = superlu_module_getpermc(permc_spec);
StatInit(&stat);
/* Compute direct inverse of sparse Matrix */
@@ -133,30 +136,30 @@
Py_gstrf(PyObject *self, PyObject *args, PyObject *keywds)
{
/* default value for SuperLU parameters*/
- double diag_pivot_thresh = 1.0;
int relax = 1;
int panel_size = 10;
- int permc_spec = 2;
int N, nnz;
PyArrayObject *rowind, *colptr, *nzvals;
SuperMatrix A;
PyObject *result;
+ PyObject *option_dict = NULL;
int type;
-
+ int ilu = 0;
+
static char *kwlist[] = {"N","nnz","nzvals","rowind","colptr",
- "permc_spec","diag_pivot_thresh",
- "relax", "panel_size", NULL};
+ "options", "relax", "panel_size", "ilu",
+ NULL};
int res = PyArg_ParseTupleAndKeywords(
- args, keywds, "iiO!O!O!|iddii", kwlist,
+ args, keywds, "iiO!O!O!|Oiii", kwlist,
&N, &nnz,
&PyArray_Type, &nzvals,
&PyArray_Type, &rowind,
&PyArray_Type, &colptr,
- &permc_spec,
- &diag_pivot_thresh,
+ &option_dict,
&relax,
- &panel_size);
+ &panel_size,
+ &ilu);
if (!res)
return NULL;
@@ -179,8 +182,8 @@
goto fail;
}
- result = newSciPyLUObject(&A, diag_pivot_thresh, relax,
- panel_size, permc_spec, type);
+ result = newSciPyLUObject(&A, relax,
+ panel_size, option_dict, type, ilu);
if (result == NULL) {
goto fail;
}
@@ -217,24 +220,17 @@
\n\
additional keyword arguments:\n\
-----------------------------\n\
-permc_spec specifies the matrix ordering used for the factorization\n\
- 0: natural ordering\n\
- 1: MMD applied to the structure of A^T * A\n\
- 2: MMD applied to the structure of A^T + A\n\
- 3: COLAMD, approximate minimum degree column ordering\n\
- (default: 2)\n\
+options specifies additional options for SuperLU\n\
+ (same keys and values as in superlu_options_t C structure)\n\
\n\
-diag_pivot_thresh threshhold for partial pivoting.\n\
- 0.0 <= diag_pivot_thresh <= 1.0\n\
- 0.0 corresponds to no pivoting\n\
- 1.0 corresponds to partial pivoting\n\
- (default: 1.0)\n\
-\n\
relax to control degree of relaxing supernodes\n\
(default: 1)\n\
\n\
panel_size a panel consist of at most panel_size consecutive columns.\n\
(default: 10)\n\
+\n\
+ilu whether to perform an incomplete LU decomposition\n\
+ (default: false)\n\
";
Modified: trunk/scipy/sparse/linalg/dsolve/_superluobject.c
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/_superluobject.c 2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/_superluobject.c 2010-04-27 21:56:19 UTC (rev 6350)
@@ -6,6 +6,7 @@
extern jmp_buf _superlu_py_jmpbuf;
+
/***********************************************************************
* SciPyLUObject methods
*/
@@ -268,26 +269,9 @@
return 0;
}
-colperm_t superlu_module_getpermc(int permc_spec)
-{
- switch(permc_spec) {
- case 0:
- return NATURAL;
- case 1:
- return MMD_ATA;
- case 2:
- return MMD_AT_PLUS_A;
- case 3:
- return COLAMD;
- }
- ABORT("Invalid input for permc_spec.");
- return NATURAL; /* compiler complains... */
-}
-
PyObject *
-newSciPyLUObject(SuperMatrix *A, double diag_pivot_thresh,
- int relax, int panel_size, int permc_spec,
- int intype)
+newSciPyLUObject(SuperMatrix *A, int relax, int panel_size,
+ PyObject *option_dict, int intype, int ilu)
{
/* A must be in SLU_NC format used by the factorization routine. */
@@ -302,6 +286,10 @@
n = A->ncol;
+ if (!set_superlu_options_from_dict(&options, ilu, option_dict)) {
+ return NULL;
+ }
+
/* Create SciPyLUObject */
self = PyObject_New(SciPyLUObject, &SciPySuperLUType);
if (self == NULL)
@@ -318,13 +306,9 @@
etree = intMalloc(n);
self->perm_r = intMalloc(n);
self->perm_c = intMalloc(n);
+ StatInit(&stat);
- set_default_options(&options);
- options.ColPerm=superlu_module_getpermc(permc_spec);
- options.DiagPivotThresh = diag_pivot_thresh;
- StatInit(&stat);
-
- get_perm_c(permc_spec, A, self->perm_c); /* calc column permutation */
+ get_perm_c(options.ColPerm, A, self->perm_c); /* calc column permutation */
sp_preorder(&options, A, self->perm_c, etree, &AC); /* apply column
* permutation */
@@ -333,10 +317,18 @@
PyErr_SetString(PyExc_ValueError, "Invalid type in SuperMatrix.");
goto fail;
}
- gstrf(SLU_TYPECODE_TO_NPY(A->Dtype),
- &options, &AC, relax, panel_size,
- etree, NULL, lwork, self->perm_c, self->perm_r,
- &self->L, &self->U, &stat, &info);
+ if (ilu) {
+ gsitrf(SLU_TYPECODE_TO_NPY(A->Dtype),
+ &options, &AC, relax, panel_size,
+ etree, NULL, lwork, self->perm_c, self->perm_r,
+ &self->L, &self->U, &stat, &info);
+ }
+ else {
+ gstrf(SLU_TYPECODE_TO_NPY(A->Dtype),
+ &options, &AC, relax, panel_size,
+ etree, NULL, lwork, self->perm_c, self->perm_r,
+ &self->L, &self->U, &stat, &info);
+ }
if (info) {
if (info < 0)
@@ -365,3 +357,161 @@
SciPyLU_dealloc(self);
return NULL;
}
+
+
+/***********************************************************************
+ * Preparing superlu_options_t
+ */
+
+/*
+ * Helper macros for the O& converters below.  Each converter maps a
+ * Python value onto a SuperLU enum, accepting either the constant's name
+ * as a string (e.g. "COLAMD") or the constant's integer value.
+ *
+ * ENUM_CHECK_INIT must be the first statement of a converter: it declares
+ * `i` (the integer value, -1 when `input` is not an int) and `s` (the
+ * string value, "" when `input` is not a str) from the parameter `input`.
+ * NOTE(review): in Python 2, bool is a subclass of int, so True/False
+ * pass PyInt_Check and match the constants whose values are 1/0; a
+ * Python `long` passes neither check.  Confirm this is intended.
+ */
+#define ENUM_CHECK_INIT \
+ long i = -1; \
+ char *s = ""; \
+ if (PyString_Check(input)) { \
+ s = PyString_AS_STRING(input); \
+ } \
+ if (PyInt_Check(input)) { \
+ i = PyInt_AsLong(input); \
+ }
+
+/* Last statement of a converter: reached only when no ENUM_CHECK matched,
+ * so set ValueError and report failure (0) to the argument parser. */
+#define ENUM_CHECK_FINISH \
+ PyErr_SetString(PyExc_ValueError, "unknown value"); \
+ return 0;
+
+/* Match `input` against one enum constant, by its stringized name or by
+ * its integer value; on a match, store the constant into *value and
+ * report success (1). */
+#define ENUM_CHECK(name) \
+ if (strcmp(s, #name) == 0 || i == (long)name) { *value = name; return 1; }
+
+/*
+ * O& converter for the yes_no_t options (Equil, PivotGrowth, ...).
+ * Only the Python singletons True and False are accepted; they map to
+ * YES and NO respectively.  Any other object sets ValueError and
+ * yields 0 so the argument parser aborts.
+ */
+static int yes_no_cvt(PyObject *input, yes_no_t *value)
+{
+    int is_true = (input == Py_True);
+    int is_false = (input == Py_False);
+
+    if (!is_true && !is_false) {
+        PyErr_SetString(PyExc_ValueError, "value not a boolean");
+        return 0;
+    }
+
+    *value = is_true ? YES : NO;
+    return 1;
+}
+
+/* O& converter for the "Fact" option: accepts a fact_t constant's name
+ * (string) or integer value, via the ENUM_CHECK_* macros.  Returns 1 and
+ * stores the constant on a match; otherwise sets ValueError, returns 0. */
+static int fact_cvt(PyObject *input, fact_t *value)
+{
+ ENUM_CHECK_INIT;
+ ENUM_CHECK(DOFACT);
+ ENUM_CHECK(SamePattern);
+ ENUM_CHECK(SamePattern_SameRowPerm);
+ ENUM_CHECK(FACTORED);
+ ENUM_CHECK_FINISH;
+}
+
+/* O& converter for the "RowPerm" option: accepts a rowperm_t constant's
+ * name (string) or integer value, via the ENUM_CHECK_* macros.  Returns 1
+ * and stores the constant on a match; otherwise sets ValueError, returns 0. */
+static int rowperm_cvt(PyObject *input, rowperm_t *value)
+{
+ ENUM_CHECK_INIT;
+ ENUM_CHECK(NOROWPERM);
+ ENUM_CHECK(LargeDiag);
+ ENUM_CHECK(MY_PERMR);
+ ENUM_CHECK_FINISH;
+}
+
+/* O& converter for the "ColPerm" option: accepts a colperm_t constant's
+ * name (string) or integer value, via the ENUM_CHECK_* macros.  Returns 1
+ * and stores the constant on a match; otherwise sets ValueError, returns 0.
+ * NOTE(review): linsolve.py now passes the integer permc_spec (0-3) as
+ * ColPerm, which relies on colperm_t's values matching the old
+ * permc_spec numbering (NATURAL=0 ... COLAMD=3) -- confirm against the
+ * SuperLU headers. */
+static int colperm_cvt(PyObject *input, colperm_t *value)
+{
+ ENUM_CHECK_INIT;
+ ENUM_CHECK(NATURAL);
+ ENUM_CHECK(MMD_ATA);
+ ENUM_CHECK(MMD_AT_PLUS_A);
+ ENUM_CHECK(COLAMD);
+ ENUM_CHECK(MY_PERMC);
+ ENUM_CHECK_FINISH;
+}
+
+/* O& converter for the "Trans" option: accepts a trans_t constant's name
+ * (string) or integer value, via the ENUM_CHECK_* macros.  Returns 1 and
+ * stores the constant on a match; otherwise sets ValueError, returns 0. */
+static int trans_cvt(PyObject *input, trans_t *value)
+{
+ ENUM_CHECK_INIT;
+ ENUM_CHECK(NOTRANS);
+ ENUM_CHECK(TRANS);
+ ENUM_CHECK(CONJ);
+ ENUM_CHECK_FINISH;
+}
+
+/* O& converter for the "IterRefine" option: accepts an IterRefine_t
+ * constant's name (string) or integer value, via the ENUM_CHECK_* macros.
+ * Returns 1 and stores the constant on a match; otherwise sets ValueError
+ * and returns 0. */
+static int iterrefine_cvt(PyObject *input, IterRefine_t *value)
+{
+ ENUM_CHECK_INIT;
+ ENUM_CHECK(NOREFINE);
+ ENUM_CHECK(SINGLE);
+ ENUM_CHECK(DOUBLE);
+ ENUM_CHECK(EXTRA);
+ ENUM_CHECK_FINISH;
+}
+
+/* O& converter for the "ILU_Norm" option: accepts a norm_t constant's
+ * name (string) or integer value, via the ENUM_CHECK_* macros.  Returns 1
+ * and stores the constant on a match; otherwise sets ValueError, returns 0. */
+static int norm_cvt(PyObject *input, norm_t *value)
+{
+ ENUM_CHECK_INIT;
+ ENUM_CHECK(ONE_NORM);
+ ENUM_CHECK(TWO_NORM);
+ ENUM_CHECK(INF_NORM);
+ ENUM_CHECK_FINISH;
+}
+
+/* O& converter for the "ILU_MILU" option: accepts a milu_t constant's
+ * name (string) or integer value, via the ENUM_CHECK_* macros.  Returns 1
+ * and stores the constant on a match; otherwise sets ValueError, returns 0. */
+static int milu_cvt(PyObject *input, milu_t *value)
+{
+ ENUM_CHECK_INIT;
+ ENUM_CHECK(SILU);
+ ENUM_CHECK(SMILU_1);
+ ENUM_CHECK(SMILU_2);
+ ENUM_CHECK(SMILU_3);
+ ENUM_CHECK_FINISH;
+}
+
+/*
+ * Fill *options for a SuperLU factorization.
+ *
+ * options      structure to fill.  It is first reset to SuperLU's defaults:
+ *              ilu_set_default_options() when ilu is nonzero, otherwise
+ *              set_default_options().
+ * ilu          nonzero when the options are for an incomplete LU.
+ * option_dict  NULL, None, or a dict whose keys are superlu_options_t field
+ *              names (see kwlist).  Each key present overrides the default.
+ *
+ * Returns 1 on success.  Returns 0 with a Python exception set on failure:
+ * a non-dict argument, an unknown key, or a value rejected by a converter.
+ */
+int set_superlu_options_from_dict(superlu_options_t *options,
+                                  int ilu, PyObject *option_dict)
+{
+    PyObject *args;
+    int ret;
+
+    static char *kwlist[] = {
+        "Fact", "Equil", "ColPerm", "Trans", "IterRefine",
+        "DiagPivotThresh", "PivotGrowth", "ConditionNumber",
+        "RowPerm", "SymmetricMode", "PrintStat", "ReplaceTinyPivot",
+        "SolveInitialized", "RefineInitialized", "ILU_Norm",
+        "ILU_MILU", "ILU_DropTol", "ILU_FillTol", "ILU_FillFactor",
+        "ILU_DropRule", NULL
+    };
+
+    if (ilu) {
+        ilu_set_default_options(options);
+    }
+    else {
+        set_default_options(options);
+    }
+
+    /* No options given: the defaults set above are the result.  This is
+     * success -- returning 0 here would make callers return NULL without
+     * any Python exception set. */
+    if (option_dict == NULL || option_dict == Py_None) {
+        return 1;
+    }
+
+    /* PyArg_ParseTupleAndKeywords requires a real dict for its keywords
+     * argument; report anything else as a clear TypeError instead of the
+     * internal-call SystemError it would otherwise raise. */
+    if (!PyDict_Check(option_dict)) {
+        PyErr_SetString(PyExc_TypeError, "options must be a dict");
+        return 0;
+    }
+
+    /* Reuse the keyword-argument parser: an empty positional tuple plus
+     * option_dict as keywords, so unknown keys are rejected with an error
+     * and each value goes through its type-specific converter. */
+    args = PyTuple_New(0);
+    if (args == NULL) {
+        return 0;
+    }
+    ret = PyArg_ParseTupleAndKeywords(
+        args, option_dict,
+        "|O&O&O&O&O&dO&O&O&O&O&O&O&O&O&O&dddi", kwlist,
+        fact_cvt, &options->Fact,
+        yes_no_cvt, &options->Equil,
+        colperm_cvt, &options->ColPerm,
+        trans_cvt, &options->Trans,
+        iterrefine_cvt, &options->IterRefine,
+        &options->DiagPivotThresh,
+        yes_no_cvt, &options->PivotGrowth,
+        yes_no_cvt, &options->ConditionNumber,
+        rowperm_cvt, &options->RowPerm,
+        yes_no_cvt, &options->SymmetricMode,
+        yes_no_cvt, &options->PrintStat,
+        yes_no_cvt, &options->ReplaceTinyPivot,
+        yes_no_cvt, &options->SolveInitialized,
+        yes_no_cvt, &options->RefineInitialized,
+        norm_cvt, &options->ILU_Norm,
+        milu_cvt, &options->ILU_MILU,
+        &options->ILU_DropTol,
+        &options->ILU_FillTol,
+        &options->ILU_FillFactor,
+        &options->ILU_DropRule
+        );
+    Py_DECREF(args);
+    return ret;
+}
Modified: trunk/scipy/sparse/linalg/dsolve/_superluobject.h
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/_superluobject.h 2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/_superluobject.h 2010-04-27 21:56:19 UTC (rev 6350)
@@ -33,9 +33,10 @@
int NCFormat_from_spMatrix(SuperMatrix *, int, int, int, PyArrayObject *,
PyArrayObject *, PyArrayObject *, int);
colperm_t superlu_module_getpermc(int);
-PyObject *newSciPyLUObject(SuperMatrix *, double, int, int, int, int);
+PyObject *newSciPyLUObject(SuperMatrix *, int, int, PyObject*, int, int);
+int set_superlu_options_from_dict(superlu_options_t *options,
+ int ilu, PyObject *option_dict);
-
/*
* Definitions for other SuperLU data types than Z,
* and type-generic definitions.
@@ -93,7 +94,7 @@
SuperLUStat_t *h, int *i
#define gssv_ARGS_REF a,b,c,d,e,f,g,h,i
-#define Create_Dense_Matrix_ARGS \
+#define Create_Dense_Matrix_ARGS \
SuperMatrix *a, int b, int c, void *d, int e, \
Stype_t f, Dtype_t g, Mtype_t h
#define Create_Dense_Matrix_ARGS_REF a,b,c,d,e,f,g,h
Modified: trunk/scipy/sparse/linalg/dsolve/linsolve.py
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/linsolve.py 2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/linsolve.py 2010-04-27 21:56:19 UTC (rev 6350)
@@ -90,7 +90,9 @@
flag = 0 # CSR format
b = asarray(b, dtype=A.dtype)
- return _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr, b, flag, permc_spec)[0]
+ options = dict(ColPerm=permc_spec)
+ return _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr, b, flag,
+ options=options)[0]
def splu(A, permc_spec=2, diag_pivot_thresh=1.0,
drop_tol=0.0, relax=1, panel_size=10):
@@ -114,8 +116,12 @@
if (M != N):
raise ValueError, "can only factor square matrices" #is this true?
- return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr, permc_spec,
- diag_pivot_thresh, drop_tol, relax, panel_size)
+ ilu = (drop_tol != 0)
+ options = dict(ILU_DropTol=drop_tol, DiagPivotThresh=diag_pivot_thresh,
+ ColPerm=permc_spec)
+ return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
+ relax=relax, panel_size=panel_size, ilu=ilu,
+ options=options)
def factorized( A ):
"""
More information about the Scipy-svn
mailing list