[Numpy-svn] r3321 - in trunk/numpy: core linalg
numpy-svn at scipy.org
numpy-svn at scipy.org
Thu Oct 12 20:41:46 EDT 2006
Author: oliphant
Date: 2006-10-12 19:41:44 -0500 (Thu, 12 Oct 2006)
New Revision: 3321
Modified:
trunk/numpy/core/numeric.py
trunk/numpy/linalg/linalg.py
Log:
Add solvetensor and invtensor
Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py 2006-10-12 19:29:12 UTC (rev 3320)
+++ trunk/numpy/core/numeric.py 2006-10-13 00:41:44 UTC (rev 3321)
@@ -252,7 +252,7 @@
pass
-def tensordot(a, b, axes=(-1,0)):
+def tensordot(a, b, axes=[-1,0]):
"""tensordot returns the product for any (ndim >= 1) arrays.
r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy} where
Modified: trunk/numpy/linalg/linalg.py
===================================================================
--- trunk/numpy/linalg/linalg.py 2006-10-12 19:29:12 UTC (rev 3320)
+++ trunk/numpy/linalg/linalg.py 2006-10-13 00:41:44 UTC (rev 3321)
@@ -6,7 +6,7 @@
zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, dpotrf.
"""
-__all__ = ['solve',
+__all__ = ['solve', 'solvetensor', 'invtensor',
'inv', 'cholesky',
'eigvals',
'eigvalsh', 'pinv',
@@ -21,7 +21,7 @@
newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \
maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \
argsort
-from numpy.lib import triu
+from numpy.lib import triu, iterable
from numpy.linalg import lapack_lite
fortran_int = intc
@@ -122,6 +122,36 @@
# Linear equations
+def solvetensor(a, b, axes=None):
+ """Solves the tensor equation a x = b for x
+
+ where it is assumed that all the indices of x are summed over in the product.
+
+ a can be N-dimensional. x will have the dimensions of A subtracted from
+ the dimensions of b.
+ """
+ a = asarray(a)
+ b = asarray(b)
+ an = a.ndim
+
+ if axes is not None:
+ allaxes = range(0,an)
+ for k in axes:
+ allaxes.remove(k)
+ allaxes.insert(an, k)
+ a = a.transpose(allaxes)
+
+ oldshape = a.shape[-(an-b.ndim):]
+ prod = 1
+ for k in oldshape:
+ prod *= k
+
+ a = a.reshape(-1,prod)
+ b = b.ravel()
+ res = solve(a, b)
+ res.shape = oldshape
+ return res
+
def solve(a, b):
"""Return the solution of a*x = b
"""
@@ -151,6 +181,44 @@
return b.transpose().astype(result_t)
+def invtensor(a, ind=2):
+ """Find the inverse tensor.
+
+ ind > 0 ==> first (ind) indices of a are summed over
+ ind < 0 ==> last (-ind) indices of a are summed over
+
+ if ind is a list, then it specifies the summed over axes
+
+ When the inv tensor and the tensor are summed over the
+ indicated axes a separable identity tensor remains.
+
+ The inverse has the summed over axes at the end.
+ """
+
+ a = asarray(a)
+ oldshape = a.shape
+ prod = 1
+ if iterable(ind):
+ invshape = range(a.ndim)
+ for axis in ind:
+ invshape.remove(axis)
+ invshape.insert(a.ndim,axis)
+ prod *= oldshape[axis]
+ elif ind > 0:
+ invshape = oldshape[ind:] + oldshape[:ind]
+ for k in oldshape[:ind]:
+ prod *= k
+ elif ind < 0:
+ invshape = oldshape[:-ind] + oldshape[-ind:]
+ for k in oldshape[-ind:]:
+ prod *= k
+ else:
+ raise ValueError, "Invalid ind argument."
+ a = a.reshape(-1,prod)
+ ia = inv(a)
+ return ia.reshape(*invshape)
+
+
# Matrix inversion
def inv(a):
More information about the Numpy-svn mailing list