[pypy-commit] pypy numpypy-array_prepare_-array_wrap: Implement __array_prepare__ for non-scalar

rguillebert noreply at buildbot.pypy.org
Mon Nov 18 18:32:00 CET 2013


Author: Romain Guillebert <romain.py at gmail.com>
Branch: numpypy-array_prepare_-array_wrap
Changeset: r68218:25afd81e613b
Date: 2013-11-18 18:31 +0100
http://bitbucket.org/pypy/pypy/changeset/25afd81e613b/

Log:	Implement __array_prepare__ for non-scalar

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -12,22 +12,12 @@
 from pypy.module.micronumpy.iter import PureShapeIterator
 from pypy.module.micronumpy import constants
 from pypy.module.micronumpy.support import int_w
+from pypy.module.micronumpy import interp_boxes
 
-call2_driver = jit.JitDriver(name='numpy_call2',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
-                                     'left_iter', 'right_iter', 'out_iter'])
-
-def call_prepare(self, space, w_out, w_obj, w_result):
-    if isinstance(w_out, W_NDimArray):
-        w_array = space.lookup(w_out, "__array_prepare__")
-        w_caller = w_out
-    else:
-        w_array = space.lookup(w_obj, "__array_prepare__")
-        w_caller = w_obj
+def call_prepare(space, w_obj, w_result):
+    w_array = space.lookup(w_obj, "__array_prepare__")
     if w_array:
-        w_retVal = space.get_and_call_function(w_array, w_caller, w_result, None)
+        w_retVal = space.get_and_call_function(w_array, w_obj, w_result, None)
         if not isinstance(w_retVal, W_NDimArray) and \
             not isinstance(w_retVal, interp_boxes.Box):
             raise OperationError(space.w_ValueError,
@@ -50,6 +40,11 @@
         return w_retVal
     return w_result
 
+call2_driver = jit.JitDriver(name='numpy_call2',
+                             greens = ['shapelen', 'func', 'calc_dtype',
+                                       'res_dtype'],
+                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
+                                     'left_iter', 'right_iter', 'out_iter'])
 def call2(space, shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     # handle array_priority
     # w_lhs and w_rhs could be of different ndarray subtypes. Numpy does:
@@ -78,6 +73,10 @@
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype,
                                      w_instance=lhs_for_subtype)
+        # prepare via lhs_for_subtype, the object out's subtype came from
+        out = call_prepare(space, lhs_for_subtype, out)
+    else:
+        out = call_prepare(space, out, out)
     left_iter = w_lhs.create_iter(shape)
     right_iter = w_rhs.create_iter(shape)
     out_iter = out.create_iter(shape)
@@ -107,6 +106,9 @@
 def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
+        out = call_prepare(space, w_obj, out)
+    else:
+        out = call_prepare(space, out, out)
     obj_iter = w_obj.create_iter(shape)
     out_iter = out.create_iter(shape)
     shapelen = len(shape)
diff --git a/pypy/module/micronumpy/test/test_subtype.py b/pypy/module/micronumpy/test/test_subtype.py
--- a/pypy/module/micronumpy/test/test_subtype.py
+++ b/pypy/module/micronumpy/test/test_subtype.py
@@ -260,7 +260,7 @@
         assert type(x) == ndarray
         assert a.called_wrap
 
-    def test___array_prepare__2arg(self):
+    def test___array_prepare__2arg_scalar(self):
         from numpypy import ndarray, array, add, ones
         class with_prepare(ndarray):
             def __array_prepare__(self, arr, context):
@@ -287,7 +287,7 @@
         assert x.called_prepare
         raises(TypeError, add, a, b, out=c)
 
-    def test___array_prepare__1arg(self):
+    def test___array_prepare__1arg_scalar(self):
         from numpypy import ndarray, array, log, ones
         class with_prepare(ndarray):
             def __array_prepare__(self, arr, context):
@@ -316,6 +316,60 @@
         assert x.called_prepare
         raises(TypeError, log, a, out=c)

+    def test___array_prepare__2arg_array(self):
+        """add() on arrays must call __array_prepare__ of the out array."""
+        from numpypy import ndarray, array, add, ones
+        class with_prepare(ndarray):
+            def __array_prepare__(self, arr, context):
+                retVal = array(arr).view(type=with_prepare)
+                retVal.called_prepare = True
+                return retVal
+        class with_prepare_fail(ndarray):
+            called_prepare = False
+            def __array_prepare__(self, arr, context):
+                return array(arr[0]).view(type=with_prepare)
+        a = array([1])
+        b = array([1]).view(type=with_prepare)
+        x = add(a, a, out=b)
+        assert x == 2
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        x.called_prepare = False
+        a = ones((3, 2)).view(type=with_prepare)
+        b = ones((3, 2))
+        c = ones((3, 2)).view(type=with_prepare_fail)
+        x = add(a, b, out=a)
+        assert (x == 2).all()
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        raises(TypeError, add, a, b, out=c)
+
+    def test___array_prepare__1arg_array(self):
+        """log() must call __array_prepare__ of out, else of the input."""
+        from numpypy import ndarray, array, log, ones
+        class with_prepare(ndarray):
+            def __array_prepare__(self, arr, context):
+                retVal = array(arr).view(type=with_prepare)
+                retVal.called_prepare = True
+                return retVal
+        class with_prepare_fail(ndarray):
+            def __array_prepare__(self, arr, context):
+                return array(arr[0]).view(type=with_prepare)
+        a = array([1])
+        b = array([1]).view(type=with_prepare)
+        x = log(a, out=b)
+        assert x == 0
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        x.called_prepare = False
+        a = ones((3, 2)).view(type=with_prepare)
+        b = ones((3, 2))
+        c = ones((3, 2)).view(type=with_prepare_fail)
+        x = log(a)
+        assert (x == 0).all()
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        raises(TypeError, log, a, out=c)

     def test___array_prepare__reduce(self):
         from numpypy import ndarray, array, sum, ones, add


More information about the pypy-commit mailing list