[pypy-commit] pypy scalar-operations: Convert ufunc args to scalars rather than arrays when possible

rlamy noreply at buildbot.pypy.org
Mon Jun 30 03:05:48 CEST 2014


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: scalar-operations
Changeset: r72272:fbf38c2a6bf6
Date: 2014-06-29 19:21 +0100
http://bitbucket.org/pypy/pypy/changeset/fbf38c2a6bf6/

Log:	Convert ufunc args to scalars rather than arrays when possible

diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -90,6 +90,14 @@
             w_val = dtype.coerce(space, space.wrap(0))
         return convert_to_array(space, w_val)
 
+    @staticmethod
+    def from_scalar(space, w_scalar):
+        """Wrap the scalar box w_scalar in a new 0-dim array of its own dtype."""
+        dtype = w_scalar.get_dtype(space)
+        w_arr = W_NDimArray.from_shape(space, [], dtype)  # shape [] => 0-dim
+        w_arr.set_scalar_value(w_scalar)  # store the box's value as the sole element
+        return w_arr
+
 
 def convert_to_array(space, w_obj):
     from pypy.module.micronumpy.ctors import array
diff --git a/pypy/module/micronumpy/ctors.py b/pypy/module/micronumpy/ctors.py
--- a/pypy/module/micronumpy/ctors.py
+++ b/pypy/module/micronumpy/ctors.py
@@ -4,7 +4,8 @@
 from rpython.rlib.rstring import strip_spaces
 from rpython.rtyper.lltypesystem import lltype, rffi
 from pypy.module.micronumpy import descriptor, loop
-from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
+from pypy.module.micronumpy.base import (
+    W_NDimArray, convert_to_array, W_NumpyObject)
 from pypy.module.micronumpy.converters import shape_converter
 
 
@@ -100,6 +101,17 @@
     return w_arr
 
 
+def numpify(space, w_object):
+    """Coerce w_object to a W_NumpyObject, unwrapping scalar arrays to their value."""
+    if isinstance(w_object, W_NumpyObject):  # already numpy-level: pass through as-is
+        return w_object
+    w_res = array(space, w_object)
+    if w_res.is_scalar():  # presumably a 0-dim result; callers then see a scalar, not an array
+        return w_res.get_scalar_value()
+    else:
+        return w_res
+
+
 def zeros(space, w_shape, w_dtype=None, w_order=None):
     dtype = space.interp_w(descriptor.W_Dtype,
         space.call_function(space.gettypefor(descriptor.W_Dtype), w_dtype))
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -7,6 +7,7 @@
 from rpython.tool.sourcetools import func_with_new_name
 from pypy.module.micronumpy import boxes, descriptor, loop, constants as NPY
 from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
+from pypy.module.micronumpy.ctors import numpify
 from pypy.module.micronumpy.strides import shape_agreement
 
 
@@ -17,6 +18,13 @@
 def done_if_false(dtype, val):
     return not dtype.itemtype.bool(val)
 
+def _get_dtype(space, w_npyobj):  # dtype of either a scalar box or an array
+    if isinstance(w_npyobj, boxes.W_GenericBox):
+        return w_npyobj.get_dtype(space)  # box API takes `space`
+    else:
+        assert isinstance(w_npyobj, W_NDimArray)  # anything that is not a box must be an array
+        return w_npyobj.get_dtype()  # array API takes no `space`
+
 
 class W_Ufunc(W_Root):
     _immutable_fields_ = [
@@ -385,15 +393,10 @@
         else:
             [w_lhs, w_rhs] = args_w
             w_out = None
-        if (isinstance(w_lhs, boxes.W_GenericBox) and
-                isinstance(w_rhs, boxes.W_GenericBox)):
-            w_ldtype = w_lhs.get_dtype(space)
-            w_rdtype = w_rhs.get_dtype(space)
-        else:
-            w_lhs = convert_to_array(space, w_lhs)
-            w_rhs = convert_to_array(space, w_rhs)
-            w_ldtype = w_lhs.get_dtype()
-            w_rdtype = w_rhs.get_dtype()
+        w_lhs = numpify(space, w_lhs)
+        w_rhs = numpify(space, w_rhs)
+        w_ldtype = _get_dtype(space, w_lhs)
+        w_rdtype = _get_dtype(space, w_rhs)
         if w_ldtype.is_str() and w_rdtype.is_str() and \
                 self.comparison_func:
             pass
@@ -456,7 +459,11 @@
             else:
                 out = arr
             return out
+        if isinstance(w_lhs, boxes.W_GenericBox):
+            w_lhs = W_NDimArray.from_scalar(space, w_lhs)
         assert isinstance(w_lhs, W_NDimArray)
+        if isinstance(w_rhs, boxes.W_GenericBox):
+            w_rhs = W_NDimArray.from_scalar(space, w_rhs)
         assert isinstance(w_rhs, W_NDimArray)
         new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
         new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)


More information about the pypy-commit mailing list