[Numpy-svn] r3321 - in trunk/numpy: core linalg

numpy-svn at scipy.org numpy-svn at scipy.org
Thu Oct 12 20:41:46 EDT 2006


Author: oliphant
Date: 2006-10-12 19:41:44 -0500 (Thu, 12 Oct 2006)
New Revision: 3321

Modified:
   trunk/numpy/core/numeric.py
   trunk/numpy/linalg/linalg.py
Log:
Add solvetensor and invtensor

Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2006-10-12 19:29:12 UTC (rev 3320)
+++ trunk/numpy/core/numeric.py	2006-10-13 00:41:44 UTC (rev 3321)
@@ -252,7 +252,7 @@
         pass
 
 
-def tensordot(a, b, axes=(-1,0)):
+def tensordot(a, b, axes=[-1,0]):
     """tensordot returns the product for any (ndim >= 1) arrays.
 
     r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy} where

Modified: trunk/numpy/linalg/linalg.py
===================================================================
--- trunk/numpy/linalg/linalg.py	2006-10-12 19:29:12 UTC (rev 3320)
+++ trunk/numpy/linalg/linalg.py	2006-10-13 00:41:44 UTC (rev 3321)
@@ -6,7 +6,7 @@
 zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, dpotrf.
 """
 
-__all__ = ['solve',
+__all__ = ['solve', 'solvetensor', 'invtensor',
            'inv', 'cholesky',
            'eigvals',
            'eigvalsh', 'pinv',
@@ -21,7 +21,7 @@
         newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \
         maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \
         argsort
-from numpy.lib import triu
+from numpy.lib import triu, iterable
 from numpy.linalg import lapack_lite
 
 fortran_int = intc
@@ -122,6 +122,50 @@
 
 # Linear equations
 
+def solvetensor(a, b, axes=None):
+    """Solve the tensor equation a x = b for x.
+
+    It is assumed that all the indices of x are summed over in the
+    product: b[i...] = sum over k... of a[i..., k...] * x[k...].
+
+    a can be N-dimensional.  After the optional reordering by `axes`,
+    the leading b.ndim dimensions of a correspond to b and the trailing
+    a.ndim - b.ndim dimensions give the shape of x, i.e.
+    x.shape == a.shape[b.ndim:].  The flattened system must be square
+    for solve() to succeed.
+
+    axes -- optional sequence of axes of a that are moved, in the given
+            order, to the end before solving (negative values count
+            from the end).
+    """
+    a = asarray(a)
+    b = asarray(b)
+    an = a.ndim
+
+    if axes is not None:
+        # Move the requested axes of a to the end, in the order given.
+        allaxes = list(range(an))
+        for k in axes:
+            if k < 0:
+                k += an
+            allaxes.remove(k)
+            allaxes.append(k)
+        a = a.transpose(allaxes)
+
+    # x takes the trailing a.ndim - b.ndim dimensions of a.  Slicing from
+    # b.ndim, rather than from -(an - b.ndim), yields () instead of the
+    # whole shape when a.ndim == b.ndim.
+    oldshape = a.shape[b.ndim:]
+    prod = 1
+    for k in oldshape:
+        prod *= k
+
+    a = a.reshape(-1, prod)
+    b = b.ravel()
+    res = solve(a, b)
+    res.shape = oldshape
+    return res
+
 def solve(a, b):
     """Return the solution of a*x = b
     """
@@ -151,6 +181,51 @@
         return b.transpose().astype(result_t)
 
 
+def invtensor(a, ind=2):
+    """Find the inverse tensor.
+
+    ind > 0 ==> first (ind) indices of a are summed over
+    ind < 0 ==> last (-ind) indices of a are summed over
+
+    If ind is a sequence, it specifies the summed-over axes of a; they are
+    moved (in the given order) to the end of a, which is then handled as
+    for ind = -len(ind).  Negative axes count from the end.
+
+    The inverse has shape a.shape[ind:] + a.shape[:ind] (taken after any
+    reordering).  When the inverse tensor and the tensor are summed over
+    the indicated axes a separable identity tensor remains.  For ind > 0
+    the summed-over axes of the inverse are at its end; for ind < 0 (or a
+    sequence) they are at its front.
+
+    The flattened 2-D array must be square for inv() to succeed.
+    Raises ValueError if ind is 0 or an empty sequence.
+    """
+    a = asarray(a)
+    if iterable(ind):
+        # Move the listed axes to the end, then treat them like ind < 0.
+        summed = list(ind)
+        order = list(range(a.ndim))
+        for axis in summed:
+            if axis < 0:
+                axis += a.ndim
+            order.remove(axis)
+            order.append(axis)
+        a = a.transpose(order)
+        ind = -len(summed)
+    if ind == 0:
+        raise ValueError("Invalid ind argument.")
+
+    oldshape = a.shape
+    invshape = oldshape[ind:] + oldshape[:ind]
+    # In C order the trailing axes oldshape[ind:] form the columns of the
+    # 2-D view, whichever sign ind has.
+    prod = 1
+    for k in oldshape[ind:]:
+        prod *= k
+    ia = inv(a.reshape(-1, prod))
+    return ia.reshape(invshape)
+
+
 # Matrix inversion
 
 def inv(a):




More information about the Numpy-svn mailing list