[Scipy-svn] r3886 - in trunk/scipy: sandbox/multigrid sparse

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Feb 1 14:41:43 EST 2008


Author: wnbell
Date: 2008-02-01 13:41:26 -0600 (Fri, 01 Feb 2008)
New Revision: 3886

Modified:
   trunk/scipy/sandbox/multigrid/sa.py
   trunk/scipy/sparse/spfuncs.py
Log:
accept options to SA solver


Modified: trunk/scipy/sandbox/multigrid/sa.py
===================================================================
--- trunk/scipy/sandbox/multigrid/sa.py	2008-02-01 03:13:28 UTC (rev 3885)
+++ trunk/scipy/sandbox/multigrid/sa.py	2008-02-01 19:41:26 UTC (rev 3886)
@@ -1,3 +1,5 @@
+"""Functions for Smoothed Aggregation AMG"""
+
 from numpy import array, arange, ones, zeros, sqrt, asarray, empty, diff
 from scipy.sparse import csr_matrix, isspmatrix_csr, bsr_matrix, isspmatrix_bsr
 
@@ -24,6 +26,7 @@
         return A
 
     if isspmatrix_csr(A): 
+        #TODO rework this
         Sp,Sj,Sx = multigridtools.sa_strong_connections(A.shape[0],epsilon,A.indptr,A.indices,A.data)
         return csr_matrix((Sx,Sj,Sp),shape=A.shape)
     elif ispmatrix_bsr(A):
@@ -198,14 +201,51 @@
     return P
 
 
+def sa_prolongator(A, B, strength='standard', aggregate='standard', smooth='standard'):  # returns (P, B): smoothed prolongator and candidates from sa_fit_candidates

-def smoothed_aggregation_solver(A, B=None, 
-        max_levels = 10, 
-        max_coarse = 500,
-        strength   = sa_strong_connections, 
-        aggregate  = sa_standard_aggregation,
-        tentative  = sa_fit_candidates,
-        smooth     = sa_smoothed_prolongator):
+    def unpack_arg(v):  # option is either 'name' or a ('name', {kwargs}) tuple
+        if isinstance(v,tuple):
+            return v[0],v[1]
+        else:
+            return v,{}
+
+    # strength of connection
+    fn, kwargs = unpack_arg(strength)
+    if fn == 'standard':
+        C = sa_strong_connections(A,**kwargs)
+    elif fn == 'ode':
+        C = sa_ode_strong_connections(A,B,**kwargs)
+    else:
+        raise ValueError('unrecognized strength of connection method: %s' % fn)
+
+    # aggregation
+    fn, kwargs = unpack_arg(aggregate)
+    if fn == 'standard':
+        AggOp = sa_standard_aggregation(C,**kwargs)
+    else:
+        raise ValueError('unrecognized aggregation method: %s' % fn)  # needs %s, else TypeError
+
+    # tentative prolongator
+    T,B = sa_fit_candidates(AggOp,B)
+
+    # tentative prolongator smoother
+    fn, kwargs = unpack_arg(smooth)
+    if fn == 'standard':
+        P = sa_smoothed_prolongator(A,T,**kwargs)
+    elif fn == 'energy_min':
+        P = sa_energy_min(A,T,C,B,**kwargs)
+    else:
+        raise ValueError('unrecognized prolongation smoother method: %s' % fn)  # '% ' was an incomplete format
+    # B was replaced by the candidates returned from sa_fit_candidates above
+    return P,B
+
+
+
+
+
+
+def smoothed_aggregation_solver(A, B=None, max_levels = 10, max_coarse = 500,
+                                solver = multilevel_solver, **kwargs):
     """Create a multilevel solver using Smoothed Aggregation (SA)
 
     *Parameters*:
@@ -219,19 +259,23 @@
             Maximum number of levels to be used in the multilevel solver.
         max_coarse: {integer} : default 500
             Maximum number of variables permitted on the coarse grid.
-        strength :
-            Function that computes the strength of connection matrix C
-                strength(A) -> C
-        aggregate : 
-            Function that computes an aggregation operator
-                aggregate(C) -> AggOp
-        tentative:
-            Function that computes a tentative prolongator
-                tentative(AggOp,B) -> T,B_coarse
-        smooth :
-            Function that smooths the tentative prolongator
-                smooth(A,C,T) -> P
+    
+    *Optional Parameters*:
+        strength : strength of connection method
+            Possible values are:
+                'standard' 
+                'ode'
+        
+        aggregate : aggregation method
+            Possible values are:
+                'standard'
+        
+        smooth : prolongation smoother
+            Possible values are:
+                'standard'
+                'energy_min'
 
+
     Unused Parameters
         epsilon: {float} : default 0.0
             Strength of connection parameter used in aggregation.
@@ -287,10 +331,7 @@
     Rs = []
 
     while len(As) < max_levels and A.shape[0] > max_coarse:
-        C     = strength(A)
-        AggOp = aggregate(C)
-        T,B   = tentative(AggOp,B)
-        P     = smooth(A,T)
+        P,B = sa_prolongator(A,B,**kwargs)
 
         R = P.T.asformat(P.format)
 
@@ -301,6 +342,6 @@
         Ps.append(P)
 
 
-    return multilevel_solver(As,Ps,Rs=Rs,preprocess=pre,postprocess=post)
+    return solver(As,Ps,Rs=Rs,preprocess=pre,postprocess=post)
 
 

Modified: trunk/scipy/sparse/spfuncs.py
===================================================================
--- trunk/scipy/sparse/spfuncs.py	2008-02-01 03:13:28 UTC (rev 3885)
+++ trunk/scipy/sparse/spfuncs.py	2008-02-01 19:41:26 UTC (rev 3886)
@@ -88,7 +88,7 @@
 
 def count_blocks(A,blocksize):
     """For a given blocksize=(r,c) count the number of occupied 
-    blocks in a sparse matrix A using 
+    blocks in a sparse matrix A
     """
     r,c = blocksize
     if r < 1 or c < 1:




More information about the Scipy-svn mailing list