[Scipy-svn] r6350 - trunk/scipy/sparse/linalg/dsolve

scipy-svn at scipy.org scipy-svn at scipy.org
Tue Apr 27 17:56:20 EDT 2010


Author: ptvirtan
Date: 2010-04-27 16:56:19 -0500 (Tue, 27 Apr 2010)
New Revision: 6350

Modified:
   trunk/scipy/sparse/linalg/dsolve/_superlumodule.c
   trunk/scipy/sparse/linalg/dsolve/_superluobject.c
   trunk/scipy/sparse/linalg/dsolve/_superluobject.h
   trunk/scipy/sparse/linalg/dsolve/linsolve.py
Log:
ENH: sparse.linalg.dsolve: allow passing any SuperLU options to the factorization in the internal gstrf interface. Enable support for incomplete LU.

Modified: trunk/scipy/sparse/linalg/dsolve/_superlumodule.c
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/_superlumodule.c	2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/_superlumodule.c	2010-04-27 21:56:19 UTC (rev 6350)
@@ -33,21 +33,22 @@
     PyArrayObject *colind=NULL, *rowptr=NULL;
     int N, nnz;
     int info;
-    int csc=0, permc_spec=2;
+    int csc=0;
     int *perm_r=NULL, *perm_c=NULL;
     SuperMatrix A, B, L, U;
     superlu_options_t options;
     SuperLUStat_t stat;
+    PyObject *option_dict = NULL;
     int type;
 
     static char *kwlist[] = {"N","nnz","nzvals","colind","rowptr","B", "csc",
-                             "permc_spec",NULL};
+                             "options",NULL};
     
     /* Get input arguments */
-    if (!PyArg_ParseTupleAndKeywords(args, kwdict, "iiO!O!O!O|ii", kwlist,
+    if (!PyArg_ParseTupleAndKeywords(args, kwdict, "iiO!O!O!O|iO", kwlist,
                                      &N, &nnz, &PyArray_Type, &nzvals,
                                      &PyArray_Type, &colind, &PyArray_Type,
-                                     &rowptr, &Py_B, &csc, &permc_spec)) {
+                                     &rowptr, &Py_B, &csc, &option_dict)) {
         return NULL;
     }
 
@@ -64,6 +65,10 @@
         return NULL;
     }
 
+    if (!set_superlu_options_from_dict(&options, 0, option_dict)) {
+        return NULL;
+    }
+
     /* Create Space for output */
     Py_X = PyArray_CopyFromObject(Py_B, type, 1, 2);
     if (Py_X == NULL) return NULL;
@@ -99,8 +104,6 @@
     else {
         perm_c = intMalloc(N);
         perm_r = intMalloc(N);
-        set_default_options(&options);
-        options.ColPerm = superlu_module_getpermc(permc_spec);
         StatInit(&stat);
 
         /* Compute direct inverse of sparse Matrix */
@@ -133,30 +136,30 @@
 Py_gstrf(PyObject *self, PyObject *args, PyObject *keywds)
 {
     /* default value for SuperLU parameters*/
-    double diag_pivot_thresh = 1.0;
     int relax = 1;
     int panel_size = 10;
-    int permc_spec = 2;
     int N, nnz;
     PyArrayObject *rowind, *colptr, *nzvals;
     SuperMatrix A;
     PyObject *result;
+    PyObject *option_dict = NULL;
     int type;
-  
+    int ilu = 0;
+
     static char *kwlist[] = {"N","nnz","nzvals","rowind","colptr",
-                             "permc_spec","diag_pivot_thresh",
-                             "relax", "panel_size", NULL};
+                             "options", "relax", "panel_size", "ilu",
+                             NULL};
 
     int res = PyArg_ParseTupleAndKeywords(
-        args, keywds, "iiO!O!O!|iddii", kwlist, 
+        args, keywds, "iiO!O!O!|Oiii", kwlist, 
         &N, &nnz,
         &PyArray_Type, &nzvals,
         &PyArray_Type, &rowind,
         &PyArray_Type, &colptr,
-        &permc_spec,
-        &diag_pivot_thresh,
+        &option_dict,
         &relax,
-        &panel_size);
+        &panel_size,
+        &ilu);
 
     if (!res)
         return NULL;
@@ -179,8 +182,8 @@
         goto fail;
     }
 
-    result = newSciPyLUObject(&A, diag_pivot_thresh, relax,
-                              panel_size, permc_spec, type);
+    result = newSciPyLUObject(&A, relax,
+                              panel_size, option_dict, type, ilu);
     if (result == NULL) {
         goto fail;
     }
@@ -217,24 +220,17 @@
 \n\
 additional keyword arguments:\n\
 -----------------------------\n\
-permc_spec          specifies the matrix ordering used for the factorization\n\
-                    0: natural ordering\n\
-                    1: MMD applied to the structure of A^T * A\n\
-                    2: MMD applied to the structure of A^T + A\n\
-                    3: COLAMD, approximate minimum degree column ordering\n\
-                    (default: 2)\n\
+options             specifies additional options for SuperLU\n\
+                    (same keys and values as in superlu_options_t C structure)\n\
 \n\
-diag_pivot_thresh   threshhold for partial pivoting.\n\
-                    0.0 <= diag_pivot_thresh <= 1.0\n\
-                    0.0 corresponds to no pivoting\n\
-                    1.0 corresponds to partial pivoting\n\
-                    (default: 1.0)\n\
-\n\
 relax               to control degree of relaxing supernodes\n\
                     (default: 1)\n\
 \n\
 panel_size          a panel consist of at most panel_size consecutive columns.\n\
                     (default: 10)\n\
+\n\
+ilu                 whether to perform an incomplete LU decomposition\n\
+                    (default: false)\n\
 ";
 
 

Modified: trunk/scipy/sparse/linalg/dsolve/_superluobject.c
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/_superluobject.c	2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/_superluobject.c	2010-04-27 21:56:19 UTC (rev 6350)
@@ -6,6 +6,7 @@
 
 extern jmp_buf _superlu_py_jmpbuf;
 
+
 /*********************************************************************** 
  * SciPyLUObject methods
  */
@@ -268,26 +269,9 @@
   return 0;
 }
 
-colperm_t superlu_module_getpermc(int permc_spec)
-{
-  switch(permc_spec) {
-  case 0:
-    return NATURAL;
-  case 1:
-    return MMD_ATA;
-  case 2:
-    return MMD_AT_PLUS_A;
-  case 3:
-    return COLAMD;
-  }
-  ABORT("Invalid input for permc_spec.");
-  return NATURAL; /* compiler complains... */
-}
-
 PyObject *
-newSciPyLUObject(SuperMatrix *A, double diag_pivot_thresh,
-		 int relax, int panel_size, int permc_spec,
-                 int intype)
+newSciPyLUObject(SuperMatrix *A, int relax, int panel_size,
+                 PyObject *option_dict, int intype, int ilu)
 {
 
    /* A must be in SLU_NC format used by the factorization routine. */
@@ -302,6 +286,10 @@
   
   n = A->ncol;
 
+  if (!set_superlu_options_from_dict(&options, ilu, option_dict)) {
+      return NULL;
+  }
+
   /* Create SciPyLUObject */
   self = PyObject_New(SciPyLUObject, &SciPySuperLUType);
   if (self == NULL)
@@ -318,13 +306,9 @@
   etree = intMalloc(n);
   self->perm_r = intMalloc(n);
   self->perm_c = intMalloc(n);
+  StatInit(&stat);
 
-  set_default_options(&options);
-  options.ColPerm=superlu_module_getpermc(permc_spec);
-  options.DiagPivotThresh = diag_pivot_thresh;
-  StatInit(&stat);
-  
-  get_perm_c(permc_spec, A, self->perm_c); /* calc column permutation */
+  get_perm_c(options.ColPerm, A, self->perm_c); /* calc column permutation */
   sp_preorder(&options, A, self->perm_c, etree, &AC); /* apply column
                                                        * permutation */
 
@@ -333,10 +317,18 @@
       PyErr_SetString(PyExc_ValueError, "Invalid type in SuperMatrix.");
       goto fail;
   }
-  gstrf(SLU_TYPECODE_TO_NPY(A->Dtype),
-        &options, &AC, relax, panel_size,
-        etree, NULL, lwork, self->perm_c, self->perm_r,
-        &self->L, &self->U, &stat, &info);
+  if (ilu) {
+      gsitrf(SLU_TYPECODE_TO_NPY(A->Dtype),
+             &options, &AC, relax, panel_size,
+             etree, NULL, lwork, self->perm_c, self->perm_r,
+             &self->L, &self->U, &stat, &info);
+  }
+  else {
+      gstrf(SLU_TYPECODE_TO_NPY(A->Dtype),
+            &options, &AC, relax, panel_size,
+            etree, NULL, lwork, self->perm_c, self->perm_r,
+            &self->L, &self->U, &stat, &info);
+  }
 
   if (info) {
     if (info < 0)
@@ -365,3 +357,161 @@
   SciPyLU_dealloc(self);
   return NULL;
 }
+
+
+/***********************************************************************
+ * Preparing superlu_options_t
+ */
+
+/*
+ * Helpers for writing O& converters that map a Python value onto a SuperLU
+ * enum.  They are used together inside a function whose parameters are
+ * named `PyObject *input` and `<enum type> *value`:
+ *
+ *   ENUM_CHECK_INIT    declares `i` (integer value of input, else -1) and
+ *                      `s` (C string of input, else ""); because it
+ *                      introduces declarations it must open the body.
+ *   ENUM_CHECK(name)   if input is the string "name" or equals the integer
+ *                      value of enumerator `name`, stores `name` into
+ *                      *value and returns 1 from the enclosing function.
+ *   ENUM_CHECK_FINISH  sets ValueError("unknown value") and returns 0.
+ *
+ * NOTE(review): only PyString / PyInt inputs are inspected (Python 2 C API);
+ * any other type leaves i == -1 and s == "" and so normally falls through
+ * to ENUM_CHECK_FINISH.
+ */
+#define ENUM_CHECK_INIT                         \
+    long i = -1;                                \
+    char *s = "";                               \
+    if (PyString_Check(input)) {                \
+        s = PyString_AS_STRING(input);          \
+    }                                           \
+    if (PyInt_Check(input)) {                   \
+        i = PyInt_AsLong(input);                \
+    }
+
+/* Fallback once no ENUM_CHECK matched: report failure to the O& caller. */
+#define ENUM_CHECK_FINISH                               \
+    PyErr_SetString(PyExc_ValueError, "unknown value"); \
+    return 0;
+
+/* Accept enumerator `name` by its spelled name (#name) or numeric value. */
+#define ENUM_CHECK(name) \
+    if (strcmp(s, #name) == 0 || i == (long)name) { *value = name; return 1; }
+
+/*
+ * O& converter for yes_no_t option fields.
+ * Only the singletons True and False are accepted; True maps to YES and
+ * False to NO.  Returns 1 on success, or 0 with ValueError set.
+ */
+static int yes_no_cvt(PyObject *input, yes_no_t *value)
+{
+    if (input != Py_True && input != Py_False) {
+        PyErr_SetString(PyExc_ValueError, "value not a boolean");
+        return 0;
+    }
+    *value = (input == Py_True) ? YES : NO;
+    return 1;
+}
+
+/* O& converter for options->Fact: accepts the fact_t enumerator name as a
+ * string or its integer value; returns 1 on match, else 0 with ValueError. */
+static int fact_cvt(PyObject *input, fact_t *value)
+{
+    ENUM_CHECK_INIT;
+    ENUM_CHECK(DOFACT);
+    ENUM_CHECK(SamePattern);
+    ENUM_CHECK(SamePattern_SameRowPerm);
+    ENUM_CHECK(FACTORED);
+    ENUM_CHECK_FINISH;
+}
+
+/* O& converter for options->RowPerm: accepts the rowperm_t enumerator name
+ * as a string or its integer value; returns 1 on match, else 0 with
+ * ValueError. */
+static int rowperm_cvt(PyObject *input, rowperm_t *value)
+{
+    ENUM_CHECK_INIT;
+    ENUM_CHECK(NOROWPERM);
+    ENUM_CHECK(LargeDiag);
+    ENUM_CHECK(MY_PERMR);
+    ENUM_CHECK_FINISH;
+}
+
+/* O& converter for options->ColPerm: accepts the colperm_t enumerator name
+ * as a string or its integer value (linsolve.py passes the integer
+ * permc_spec here); returns 1 on match, else 0 with ValueError. */
+static int colperm_cvt(PyObject *input, colperm_t *value)
+{
+    ENUM_CHECK_INIT;
+    ENUM_CHECK(NATURAL);
+    ENUM_CHECK(MMD_ATA);
+    ENUM_CHECK(MMD_AT_PLUS_A);
+    ENUM_CHECK(COLAMD);
+    ENUM_CHECK(MY_PERMC);
+    ENUM_CHECK_FINISH;
+}
+
+/* O& converter for options->Trans: accepts the trans_t enumerator name as a
+ * string or its integer value; returns 1 on match, else 0 with ValueError. */
+static int trans_cvt(PyObject *input, trans_t *value)
+{
+    ENUM_CHECK_INIT;
+    ENUM_CHECK(NOTRANS);
+    ENUM_CHECK(TRANS);
+    ENUM_CHECK(CONJ);
+    ENUM_CHECK_FINISH;
+}
+
+/* O& converter for options->IterRefine: accepts the IterRefine_t
+ * enumerator name as a string or its integer value; returns 1 on match,
+ * else 0 with ValueError. */
+static int iterrefine_cvt(PyObject *input, IterRefine_t *value)
+{
+    ENUM_CHECK_INIT;
+    ENUM_CHECK(NOREFINE);
+    ENUM_CHECK(SINGLE);
+    ENUM_CHECK(DOUBLE);
+    ENUM_CHECK(EXTRA);
+    ENUM_CHECK_FINISH;
+}
+
+/* O& converter for options->ILU_Norm: accepts the norm_t enumerator name as
+ * a string or its integer value; returns 1 on match, else 0 with
+ * ValueError. */
+static int norm_cvt(PyObject *input, norm_t *value)
+{
+    ENUM_CHECK_INIT;
+    ENUM_CHECK(ONE_NORM);
+    ENUM_CHECK(TWO_NORM);
+    ENUM_CHECK(INF_NORM);
+    ENUM_CHECK_FINISH;
+}
+
+/* O& converter for options->ILU_MILU: accepts the milu_t enumerator name as
+ * a string or its integer value; returns 1 on match, else 0 with
+ * ValueError. */
+static int milu_cvt(PyObject *input, milu_t *value)
+{
+    ENUM_CHECK_INIT;
+    ENUM_CHECK(SILU);
+    ENUM_CHECK(SMILU_1);
+    ENUM_CHECK(SMILU_2);
+    ENUM_CHECK(SMILU_3);
+    ENUM_CHECK_FINISH;
+}
+
+/*
+ * Fill *options with SuperLU defaults, then apply overrides from a dict.
+ *
+ * options      structure to initialize; reset to ilu_set_default_options()
+ *              when `ilu` is nonzero, otherwise to set_default_options().
+ * ilu          nonzero to start from the incomplete-LU defaults.
+ * option_dict  NULL, None, or a dict whose keys are superlu_options_t field
+ *              names (see kwlist).  Enum fields take the enumerator name or
+ *              its integer value, yes_no_t fields take True/False, and the
+ *              remaining fields take a float (or an int for ILU_DropRule).
+ *
+ * Returns 1 on success.  Returns 0 with a Python exception set on failure;
+ * callers propagate 0 by returning NULL, so 0 must never be returned
+ * without an exception.
+ */
+int set_superlu_options_from_dict(superlu_options_t *options,
+                                  int ilu, PyObject *option_dict)
+{
+    PyObject *args;
+    int ret;
+
+    static char *kwlist[] = {
+        "Fact", "Equil", "ColPerm", "Trans", "IterRefine",
+        "DiagPivotThresh", "PivotGrowth", "ConditionNumber",
+        "RowPerm", "SymmetricMode", "PrintStat", "ReplaceTinyPivot",
+        "SolveInitialized", "RefineInitialized", "ILU_Norm",
+        "ILU_MILU", "ILU_DropTol", "ILU_FillTol", "ILU_FillFactor",
+        "ILU_DropRule", NULL
+    };
+
+    if (ilu) {
+        ilu_set_default_options(options);
+    }
+    else {
+        set_default_options(options);
+    }
+
+    /* No overrides requested: the defaults are a complete, valid result.
+     * (Returning 0 here made callers return NULL with no exception set.) */
+    if (option_dict == NULL || option_dict == Py_None) {
+        return 1;
+    }
+
+    /* PyArg_ParseTupleAndKeywords requires a real dict for the keywords;
+     * give a clear error instead of its internal-argument SystemError. */
+    if (!PyDict_Check(option_dict)) {
+        PyErr_SetString(PyExc_TypeError, "options must be a dict");
+        return 0;
+    }
+
+    /* Reuse the keyword parser to match dict keys to fields: the positional
+     * tuple is empty, so every option is looked up by name in kwlist. */
+    args = PyTuple_New(0);
+    if (args == NULL) {
+        return 0;
+    }
+    ret = PyArg_ParseTupleAndKeywords(
+        args, option_dict,
+        "|O&O&O&O&O&dO&O&O&O&O&O&O&O&O&O&dddi", kwlist,
+        fact_cvt, &options->Fact,
+        yes_no_cvt, &options->Equil,
+        colperm_cvt, &options->ColPerm,
+        trans_cvt, &options->Trans,
+        iterrefine_cvt, &options->IterRefine,
+        &options->DiagPivotThresh,
+        yes_no_cvt, &options->PivotGrowth,
+        yes_no_cvt, &options->ConditionNumber,
+        rowperm_cvt, &options->RowPerm,
+        yes_no_cvt, &options->SymmetricMode,
+        yes_no_cvt, &options->PrintStat,
+        yes_no_cvt, &options->ReplaceTinyPivot,
+        yes_no_cvt, &options->SolveInitialized,
+        yes_no_cvt, &options->RefineInitialized,
+        norm_cvt, &options->ILU_Norm,
+        milu_cvt, &options->ILU_MILU,
+        &options->ILU_DropTol,
+        &options->ILU_FillTol,
+        &options->ILU_FillFactor,
+        &options->ILU_DropRule
+        );
+    Py_DECREF(args);
+    return ret;
+}

Modified: trunk/scipy/sparse/linalg/dsolve/_superluobject.h
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/_superluobject.h	2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/_superluobject.h	2010-04-27 21:56:19 UTC (rev 6350)
@@ -33,9 +33,10 @@
 int NCFormat_from_spMatrix(SuperMatrix *, int, int, int, PyArrayObject *,
                            PyArrayObject *, PyArrayObject *, int);
 colperm_t superlu_module_getpermc(int);
-PyObject *newSciPyLUObject(SuperMatrix *, double, int, int, int, int);
+PyObject *newSciPyLUObject(SuperMatrix *, int, int, PyObject*, int, int);
+int set_superlu_options_from_dict(superlu_options_t *options,
+                                  int ilu, PyObject *option_dict);
 
-
 /*
  * Definitions for other SuperLU data types than Z,
  * and type-generic definitions.
@@ -93,7 +94,7 @@
     SuperLUStat_t *h, int *i
 #define gssv_ARGS_REF a,b,c,d,e,f,g,h,i
 
-#define Create_Dense_Matrix_ARGS                               \ 
+#define Create_Dense_Matrix_ARGS                               \
     SuperMatrix *a, int b, int c, void *d, int e,              \
     Stype_t f, Dtype_t g, Mtype_t h
 #define Create_Dense_Matrix_ARGS_REF a,b,c,d,e,f,g,h

Modified: trunk/scipy/sparse/linalg/dsolve/linsolve.py
===================================================================
--- trunk/scipy/sparse/linalg/dsolve/linsolve.py	2010-04-27 21:56:00 UTC (rev 6349)
+++ trunk/scipy/sparse/linalg/dsolve/linsolve.py	2010-04-27 21:56:19 UTC (rev 6350)
@@ -90,7 +90,9 @@
             flag = 0 # CSR format
 
         b = asarray(b, dtype=A.dtype)
-        return _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr, b, flag, permc_spec)[0]
+        options = dict(ColPerm=permc_spec)
+        return _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr, b, flag,
+                             options=options)[0]
 
 def splu(A, permc_spec=2, diag_pivot_thresh=1.0,
          drop_tol=0.0, relax=1, panel_size=10):
@@ -114,8 +116,12 @@
     if (M != N):
         raise ValueError, "can only factor square matrices" #is this true?
 
-    return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr, permc_spec,
-                          diag_pivot_thresh, drop_tol, relax, panel_size)
+    ilu = (drop_tol != 0)
+    options = dict(ILU_DropTol=drop_tol, DiagPivotThresh=diag_pivot_thresh,
+                   ColPerm=permc_spec)
+    return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
+                          relax=relax, panel_size=panel_size, ilu=ilu,
+                          options=options)
 
 def factorized( A ):
     """




More information about the Scipy-svn mailing list