[Numpy-svn] r3293 - in trunk/numpy/lib: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Mon Oct 9 03:47:10 EDT 2006
Author: oliphant
Date: 2006-10-09 02:47:06 -0500 (Mon, 09 Oct 2006)
New Revision: 3293
Modified:
trunk/numpy/lib/shape_base.py
trunk/numpy/lib/tests/test_shape_base.py
Log:
Fix kron for multidimensional arrays. kron is now defined so that tile(b, s) is the same as kron(ones(s, b.dtype), b).
Modified: trunk/numpy/lib/shape_base.py
===================================================================
--- trunk/numpy/lib/shape_base.py 2006-10-08 13:16:13 UTC (rev 3292)
+++ trunk/numpy/lib/shape_base.py 2006-10-09 07:47:06 UTC (rev 3293)
@@ -1,7 +1,7 @@
__all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
'column_stack','row_stack', 'dstack','array_split','split','hsplit',
'vsplit','dsplit','apply_over_axes','expand_dims',
- 'apply_along_axis', 'kron', 'tile']
+ 'apply_along_axis', 'kron', 'tile', 'get_array_wrap']
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, zeros, newaxis, outer, \
@@ -526,7 +526,7 @@
raise ValueError, 'vsplit only works on arrays of 3 or more dimensions'
return split(ary,indices_or_sections,2)
-def _getwrapper(*args):
+def get_array_wrap(*args):
"""Find the wrapper for the array with the highest priority.
In case of ties, leftmost wins. If no wrapper is found, return None
@@ -547,19 +547,28 @@
[ ... ... ],
[ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]]
"""
- wrapper = _getwrapper(a, b)
+ wrapper = get_array_wrap(a, b)
b = asanyarray(b)
a = array(a,copy=False,subok=True,ndmin=b.ndim)
+ ndb, nda = b.ndim, a.ndim
+ if (nda == 0 or ndb == 0):
+ return a * b
as = a.shape
bs = b.shape
if not a.flags.contiguous:
a = reshape(a, as)
if not b.flags.contiguous:
b = reshape(b, bs)
- o = outer(a,b)
- result = o.reshape(as + bs)
- axis = a.ndim-1
- for k in xrange(b.ndim):
+ nd = ndb
+ if (ndb != nda):
+ if (ndb > nda):
+ as = (1,)*(ndb-nda) + as
+ else:
+ bs = (1,)*(nda-ndb) + bs
+ nd = nda
+ result = outer(a,b).reshape(as+bs)
+ axis = nd-1
+ for k in xrange(nd):
result = concatenate(result, axis=axis)
if wrapper is not None:
result = wrapper(result)
Modified: trunk/numpy/lib/tests/test_shape_base.py
===================================================================
--- trunk/numpy/lib/tests/test_shape_base.py 2006-10-08 13:16:13 UTC (rev 3292)
+++ trunk/numpy/lib/tests/test_shape_base.py 2006-10-09 07:47:06 UTC (rev 3293)
@@ -11,8 +11,6 @@
a = ones((20,10),'d')
assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
def check_simple101(self,level=11):
- # This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4)
- # when enabled and shape(a)[1]>100. See Issue 202.
a = ones((10,101),'d')
assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
@@ -370,6 +368,7 @@
assert_equal(type(kron(a,ma)), ndarray)
assert_equal(type(kron(ma,a)), myarray)
+
class test_tile(NumpyTestCase):
def check_basic(self):
a = array([0,1,2])
@@ -380,7 +379,19 @@
assert_equal(tile(b, 2), [[1,2,1,2],[3,4,3,4]])
assert_equal(tile(b,(2,1)),[[1,2],[3,4],[1,2],[3,4]])
assert_equal(tile(b,(2,2)),[[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]])
-
+
+ def check_kroncompare(self):
+ import numpy.random as nr
+ reps=[(2,),(1,2),(2,1),(2,2),(2,3,2),(3,2)]
+ shape=[(3,),(2,3),(3,4,3),(3,2,3),(4,3,2,4),(2,2)]
+ for s in shape:
+ b = nr.randint(0,10,size=s)
+ for r in reps:
+ a = ones(r, b.dtype)
+ large = tile(b, r)
+ klarge = kron(a, b)
+ assert_equal(large, klarge)
+
# Utility
def compare_results(res,desired):
More information about the Numpy-svn mailing list