[Scipy-svn] r3886 - in trunk/scipy: sandbox/multigrid sparse
scipy-svn at scipy.org
scipy-svn at scipy.org
Fri Feb 1 14:41:43 EST 2008
Author: wnbell
Date: 2008-02-01 13:41:26 -0600 (Fri, 01 Feb 2008)
New Revision: 3886
Modified:
trunk/scipy/sandbox/multigrid/sa.py
trunk/scipy/sparse/spfuncs.py
Log:
accept options to SA solver
Modified: trunk/scipy/sandbox/multigrid/sa.py
===================================================================
--- trunk/scipy/sandbox/multigrid/sa.py 2008-02-01 03:13:28 UTC (rev 3885)
+++ trunk/scipy/sandbox/multigrid/sa.py 2008-02-01 19:41:26 UTC (rev 3886)
@@ -1,3 +1,5 @@
+"""Functions for Smoothed Aggregation AMG"""
+
from numpy import array, arange, ones, zeros, sqrt, asarray, empty, diff
from scipy.sparse import csr_matrix, isspmatrix_csr, bsr_matrix, isspmatrix_bsr
@@ -24,6 +26,7 @@
return A
if isspmatrix_csr(A):
+ #TODO rework this
Sp,Sj,Sx = multigridtools.sa_strong_connections(A.shape[0],epsilon,A.indptr,A.indices,A.data)
return csr_matrix((Sx,Sj,Sp),shape=A.shape)
elif ispmatrix_bsr(A):
@@ -198,14 +201,51 @@
return P
+def sa_prolongator(A, B, strength='standard', aggregate='standard', smooth='standard'):
-def smoothed_aggregation_solver(A, B=None,
- max_levels = 10,
- max_coarse = 500,
- strength = sa_strong_connections,
- aggregate = sa_standard_aggregation,
- tentative = sa_fit_candidates,
- smooth = sa_smoothed_prolongator):
+    """Build the prolongator for one level of Smoothed Aggregation (SA).
+
+    Each of ``strength``, ``aggregate`` and ``smooth`` is either a method
+    name (str) or a ``(name, kwargs)`` tuple whose dict is forwarded as
+    keyword arguments to the selected routine:
+
+    strength  : 'standard'   -> sa_strong_connections(A, **kwargs)
+                'ode'        -> sa_ode_strong_connections(A, B, **kwargs)
+    aggregate : 'standard'   -> sa_standard_aggregation(C, **kwargs)
+    smooth    : 'standard'   -> sa_smoothed_prolongator(A, T, **kwargs)
+                'energy_min' -> sa_energy_min(A, T, C, B, **kwargs)
+
+    The tentative prolongator T is always built by sa_fit_candidates.
+
+    Returns (P, B): the smoothed prolongator and the coarse-level
+    candidates returned by sa_fit_candidates.
+    Raises ValueError if a method name is not recognized.
+    """
+
+    def unpack_arg(v):
+        # split an option into (method name, keyword-argument dict)
+        if isinstance(v, tuple):
+            return v[0], v[1]
+        else:
+            return v, {}
+
+    # strength of connection
+    fn, kwargs = unpack_arg(strength)
+    if fn == 'standard':
+        C = sa_strong_connections(A, **kwargs)
+    elif fn == 'ode':
+        C = sa_ode_strong_connections(A, B, **kwargs)
+    else:
+        raise ValueError('unrecognized strength of connection method: %s' % fn)
+
+    # aggregation
+    fn, kwargs = unpack_arg(aggregate)
+    if fn == 'standard':
+        AggOp = sa_standard_aggregation(C, **kwargs)
+    else:
+        # format string needs a %s, otherwise '%' raises TypeError
+        raise ValueError('unrecognized aggregation method: %s' % fn)
+
+    # tentative prolongator (B is replaced by the coarse-level candidates)
+    T, B = sa_fit_candidates(AggOp, B)
+
+    # tentative prolongator smoother
+    fn, kwargs = unpack_arg(smooth)
+    if fn == 'standard':
+        P = sa_smoothed_prolongator(A, T, **kwargs)
+    elif fn == 'energy_min':
+        P = sa_energy_min(A, T, C, B, **kwargs)
+    else:
+        # '% ' was an incomplete format spec; use %s like the cases above
+        raise ValueError('unrecognized prolongation smoother method: %s' % fn)
+
+    return P, B
+
+
+
+
+
+
+def smoothed_aggregation_solver(A, B=None, max_levels = 10, max_coarse = 500,
+ solver = multilevel_solver, **kwargs):
"""Create a multilevel solver using Smoothed Aggregation (SA)
*Parameters*:
@@ -219,19 +259,23 @@
Maximum number of levels to be used in the multilevel solver.
max_coarse: {integer} : default 500
Maximum number of variables permitted on the coarse grid.
- strength :
- Function that computes the strength of connection matrix C
- strength(A) -> C
- aggregate :
- Function that computes an aggregation operator
- aggregate(C) -> AggOp
- tentative:
- Function that computes a tentative prolongator
- tentative(AggOp,B) -> T,B_coarse
- smooth :
- Function that smooths the tentative prolongator
- smooth(A,C,T) -> P
+
+ *Optional Parameters*:
+ strength : strength of connection method
+ Possible values are:
+ 'standard'
+ 'ode'
+
+ aggregate : aggregation method
+ Possible values are:
+ 'standard'
+
+ smooth : prolongation smoother
+ Possible values are:
+ 'standard'
+ 'energy_min'
+
Unused Parameters
epsilon: {float} : default 0.0
Strength of connection parameter used in aggregation.
@@ -287,10 +331,7 @@
Rs = []
while len(As) < max_levels and A.shape[0] > max_coarse:
- C = strength(A)
- AggOp = aggregate(C)
- T,B = tentative(AggOp,B)
- P = smooth(A,T)
+ P,B = sa_prolongator(A,B,**kwargs)
R = P.T.asformat(P.format)
@@ -301,6 +342,6 @@
Ps.append(P)
- return multilevel_solver(As,Ps,Rs=Rs,preprocess=pre,postprocess=post)
+ return solver(As,Ps,Rs=Rs,preprocess=pre,postprocess=post)
Modified: trunk/scipy/sparse/spfuncs.py
===================================================================
--- trunk/scipy/sparse/spfuncs.py 2008-02-01 03:13:28 UTC (rev 3885)
+++ trunk/scipy/sparse/spfuncs.py 2008-02-01 19:41:26 UTC (rev 3886)
@@ -88,7 +88,7 @@
def count_blocks(A,blocksize):
"""For a given blocksize=(r,c) count the number of occupied
- blocks in a sparse matrix A using
+ blocks in a sparse matrix A
"""
r,c = blocksize
if r < 1 or c < 1:
More information about the Scipy-svn mailing list