[pypy-commit] pypy numpypy-out: expose 'out' arguments; needs many more tests

mattip noreply at buildbot.pypy.org
Mon Feb 13 00:57:55 CET 2012


Author: mattip
Branch: numpypy-out
Changeset: r52405:04f3673a1de7
Date: 2012-02-13 01:57 +0200
http://bitbucket.org/pypy/pypy/changeset/04f3673a1de7/

Log:	expose 'out' arguments; needs many more tests

diff --git a/pypy/module/micronumpy/interp_boxes.py b/pypy/module/micronumpy/interp_boxes.py
--- a/pypy/module/micronumpy/interp_boxes.py
+++ b/pypy/module/micronumpy/interp_boxes.py
@@ -59,21 +59,24 @@
         return space.wrap(dtype.itemtype.bool(self))
 
     def _binop_impl(ufunc_name):
-        def impl(self, space, w_other):
+        def impl(self, space, w_other, w_out=None):
             from pypy.module.micronumpy import interp_ufuncs
-            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [self, w_other])
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space,
+                                                            [self, w_other, w_out])
         return func_with_new_name(impl, "binop_%s_impl" % ufunc_name)
 
     def _binop_right_impl(ufunc_name):
-        def impl(self, space, w_other):
+        def impl(self, space, w_other, w_out=None):
             from pypy.module.micronumpy import interp_ufuncs
-            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [w_other, self])
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, 
+                                                            [w_other, self, w_out])
         return func_with_new_name(impl, "binop_right_%s_impl" % ufunc_name)
 
     def _unaryop_impl(ufunc_name):
-        def impl(self, space):
+        def impl(self, space, w_out=None):
             from pypy.module.micronumpy import interp_ufuncs
-            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [self])
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space,
+                                                                    [self, w_out])
         return func_with_new_name(impl, "unaryop_%s_impl" % ufunc_name)
 
     descr_add = _binop_impl("add")
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -83,8 +83,9 @@
         return space.wrap(W_NDimArray(size, shape[:], dtype=dtype))
 
     def _unaryop_impl(ufunc_name):
-        def impl(self, space):
-            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [self])
+        def impl(self, space, w_out=None):
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space,
+                                                                [self, w_out])
         return func_with_new_name(impl, "unaryop_%s_impl" % ufunc_name)
 
     descr_pos = _unaryop_impl("positive")
@@ -93,8 +94,9 @@
     descr_invert = _unaryop_impl("invert")
 
     def _binop_impl(ufunc_name):
-        def impl(self, space, w_other):
-            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [self, w_other])
+        def impl(self, space, w_other, w_out=None):
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space,
+                                                        [self, w_other, w_out])
         return func_with_new_name(impl, "binop_%s_impl" % ufunc_name)
 
     descr_add = _binop_impl("add")
@@ -123,12 +125,12 @@
         return space.newtuple([w_quotient, w_remainder])
 
     def _binop_right_impl(ufunc_name):
-        def impl(self, space, w_other):
+        def impl(self, space, w_other, w_out=None):
             w_other = scalar_w(space,
                 interp_ufuncs.find_dtype_for_scalar(space, w_other, self.find_dtype()),
                 w_other
             )
-            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [w_other, self])
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [w_other, self, w_out])
         return func_with_new_name(impl, "binop_right_%s_impl" % ufunc_name)
 
     descr_radd = _binop_right_impl("add")
@@ -155,11 +157,11 @@
                 axis = -1
             else:
                 axis = space.int_w(w_axis)
-            if space.is_w(w_out, space.w_None):
+            if space.is_w(w_out, space.w_None) or not w_out:
                 out = None
             elif not isinstance(w_out, BaseArray):
-                raise OperationError(space.w_TypeError, space.wrap(
-                                                    'output must be an array'))
+                raise OperationError(space.w_TypeError, space.wrap( 
+                        'output must be an array'))
             else:
                 out = w_out
             return getattr(interp_ufuncs.get(space), ufunc_name).reduce(space,
@@ -215,14 +217,15 @@
     descr_argmax = _reduce_argmax_argmin_impl("max")
     descr_argmin = _reduce_argmax_argmin_impl("min")
 
-    def descr_dot(self, space, w_other):
+    def descr_dot(self, space, w_other, w_out=None):
         other = convert_to_array(space, w_other)
         if isinstance(other, Scalar):
+            #Note: w_out is not modified, this is numpy compliant.
             return self.descr_mul(space, other)
         elif len(self.shape) < 2 and len(other.shape) < 2:
-            w_res = self.descr_mul(space, other)
+            w_res = self.descr_mul(space, other, w_out)
             assert isinstance(w_res, BaseArray)
-            return w_res.descr_sum(space, space.wrap(-1))
+            return w_res.descr_sum(space, space.wrap(-1), w_out)
         dtype = interp_ufuncs.find_binop_result_dtype(space,
                                      self.find_dtype(), other.find_dtype())
         if self.size < 1 and other.size < 1:
@@ -707,11 +710,12 @@
     """
     Class for representing virtual arrays, such as binary ops or ufuncs
     """
-    def __init__(self, name, shape, res_dtype):
+    def __init__(self, name, shape, res_dtype, out_arg=None):
         BaseArray.__init__(self, shape)
         self.forced_result = None
         self.res_dtype = res_dtype
         self.name = name
+        self.res = out_arg
 
     def _del_sources(self):
         # Function for deleting references to source arrays,
@@ -719,7 +723,8 @@
         raise NotImplementedError
 
     def compute(self):
-        ra = ResultArray(self, self.size, self.shape, self.res_dtype)
+        ra = ResultArray(self, self.size, self.shape, self.res_dtype,
+                                                                self.res)
         loop.compute(ra)
         return ra.left
 
@@ -766,8 +771,9 @@
 
 
 class Call1(VirtualArray):
-    def __init__(self, ufunc, name, shape, calc_dtype, res_dtype, values):
-        VirtualArray.__init__(self, name, shape, res_dtype)
+    def __init__(self, ufunc, name, shape, calc_dtype, res_dtype, values,
+                                                            out_arg=None):
+        VirtualArray.__init__(self, name, shape, res_dtype, out_arg)
         self.values = values
         self.size = values.size
         self.ufunc = ufunc
@@ -788,8 +794,9 @@
     """
     _immutable_fields_ = ['left', 'right']
 
-    def __init__(self, ufunc, name, shape, calc_dtype, res_dtype, left, right):
-        VirtualArray.__init__(self, name, shape, res_dtype)
+    def __init__(self, ufunc, name, shape, calc_dtype, res_dtype, left, right,
+            out_arg=None):
+        VirtualArray.__init__(self, name, shape, res_dtype, out_arg)
         self.ufunc = ufunc
         self.left = left
         self.right = right
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -28,14 +28,18 @@
         return self.identity
 
     def descr_call(self, space, __args__):
+        from interp_numarray import BaseArray
         args_w, kwds_w = __args__.unpack()
         # it occurs to me that we don't support any datatypes that
         # require casting, change it later when we do
         kwds_w.pop('casting', None)
         w_subok = kwds_w.pop('subok', None)
         w_out = kwds_w.pop('out', space.w_None)
-        if ((w_subok is not None and space.is_true(w_subok)) or
-            not space.is_w(w_out, space.w_None)):
+        if space.is_w(w_out, space.w_None):
+            out = None
+        else:
+            out = w_out
+        if (w_subok is not None and space.is_true(w_subok)):
             raise OperationError(space.w_NotImplementedError,
                                  space.wrap("parameters unsupported"))
         if kwds_w or len(args_w) < self.argcount:
@@ -43,11 +47,14 @@
                 space.wrap("invalid number of arguments")
             )
         elif len(args_w) > self.argcount:
-            # The extra arguments should actually be the output array, but we
-            # don't support that yet.
             raise OperationError(space.w_TypeError,
                 space.wrap("invalid number of arguments")
             )
+        elif out is not None:
+            args_w = args_w[:] + [out]
+        if args_w[-1] and not isinstance(args_w[-1], BaseArray):
+            raise OperationError(space.w_TypeError, space.wrap(
+                                            'output must be an array'))
         return self.call(space, args_w)
 
     @unwrap_spec(skipna=bool, keepdims=bool)
@@ -105,6 +112,7 @@
         array([[ 1,  5],
                [ 9, 13]])
         """
+        from pypy.module.micronumpy.interp_numarray import BaseArray
         if w_axis is None:
             axis = 0
         elif space.is_w(w_axis, space.w_None):
@@ -113,7 +121,7 @@
             axis = space.int_w(w_axis)
         if space.is_w(w_out, space.w_None):
             out = None
-        elif not isinstance(w_out, W_NDimArray):
+        elif not isinstance(w_out, BaseArray):
             raise OperationError(space.w_TypeError, space.wrap(
                                                 'output must be an array'))
         else:
@@ -165,8 +173,11 @@
                         ' does not have enough dimensions', self.name)
                 elif out.shape != shape:
                     raise operationerrfmt(space.w_ValueError,
-                        'output parameter shape mismatch, expecting %s' +
-                        ' , got %s', str(shape), str(out.shape))
+                        'output parameter shape mismatch, expecting [%s]' +
+                        ' , got [%s]', 
+                        ",".join([str(x) for x in shape]),
+                        ",".join([str(x) for x in out.shape]),
+                        )
                 #Test for dtype agreement, perhaps create an itermediate
                 #if out.dtype != dtype:
                 #    raise OperationError(space.w_TypeError, space.wrap(
@@ -182,9 +193,12 @@
                               " dimensions",self.name)
             arr = ReduceArray(self.func, self.name, self.identity, obj,
                                                             out.find_dtype())
+            val = loop.compute(arr)
+            assert isinstance(out, Scalar)
+            out.value = val
         else:
             arr = ReduceArray(self.func, self.name, self.identity, obj, dtype)
-        val = loop.compute(arr)
+            val = loop.compute(arr)
         return val 
 
     def do_axis_reduce(self, obj, dtype, axis, result):
@@ -211,7 +225,7 @@
         from pypy.module.micronumpy.interp_numarray import (Call1,
             convert_to_array, Scalar)
 
-        [w_obj] = args_w
+        [w_obj, w_out] = args_w
         w_obj = convert_to_array(space, w_obj)
         calc_dtype = find_unaryop_result_dtype(space,
                                   w_obj.find_dtype(),
@@ -244,17 +258,25 @@
 
     def call(self, space, args_w):
         from pypy.module.micronumpy.interp_numarray import (Call2,
-            convert_to_array, Scalar, shape_agreement)
+            convert_to_array, Scalar, shape_agreement, BaseArray)
 
-        [w_lhs, w_rhs] = args_w
+        [w_lhs, w_rhs, w_out] = args_w
         w_lhs = convert_to_array(space, w_lhs)
         w_rhs = convert_to_array(space, w_rhs)
-        calc_dtype = find_binop_result_dtype(space,
-            w_lhs.find_dtype(), w_rhs.find_dtype(),
-            int_only=self.int_only,
-            promote_to_float=self.promote_to_float,
-            promote_bools=self.promote_bools,
-        )
+        if space.is_w(w_out, space.w_None) or not w_out:
+            out = None
+            calc_dtype = find_binop_result_dtype(space,
+                w_lhs.find_dtype(), w_rhs.find_dtype(),
+                int_only=self.int_only,
+                promote_to_float=self.promote_to_float,
+                promote_bools=self.promote_bools,
+            )
+        elif not isinstance(w_out, BaseArray):
+            raise OperationError(space.w_TypeError, space.wrap(
+                    'output must be an array'))
+        else:
+            out = w_out
+            calc_dtype = out.find_dtype()
         if self.comparison_func:
             res_dtype = interp_dtype.get_dtype_cache(space).w_booldtype
         else:
@@ -265,9 +287,10 @@
                 w_rhs.value.convert_to(calc_dtype)
             ))
         new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
+        # Test correctness of out.shape
         w_res = Call2(self.func, self.name,
                       new_shape, calc_dtype,
-                      res_dtype, w_lhs, w_rhs)
+                      res_dtype, w_lhs, w_rhs, out)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         return w_res
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -862,7 +862,7 @@
         assert (arange(10).reshape(5, 2).mean(axis=1) == [0.5, 2.5, 4.5, 6.5, 8.5]).all()
 
     def test_sum(self):
-        from _numpypy import array,zeros
+        from _numpypy import array
         a = array(range(5))
         assert a.sum() == 10
         assert a[:4].sum() == 6
@@ -874,7 +874,7 @@
         d = array(0.)
         b = a.sum(out=d)
         assert b == d
-        assert b.dtype == d.dtype
+        assert isinstance(b, float)
 
     def test_reduce_nd(self):
         from numpypy import arange, array, multiply


More information about the pypy-commit mailing list