[pypy-commit] pypy scalar-operations: Convert ufunc args to scalars rather than arrays when possible
rlamy
noreply at buildbot.pypy.org
Mon Jun 30 03:05:48 CEST 2014
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: scalar-operations
Changeset: r72272:fbf38c2a6bf6
Date: 2014-06-29 19:21 +0100
http://bitbucket.org/pypy/pypy/changeset/fbf38c2a6bf6/
Log: Convert ufunc args to scalars rather than arrays when possible
diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -90,6 +90,14 @@
w_val = dtype.coerce(space, space.wrap(0))
return convert_to_array(space, w_val)
+ @staticmethod
+ def from_scalar(space, w_scalar):
+ """Convert a scalar box into a 0-dim array that uses the scalar's own dtype and holds the scalar as its value"""
+ dtype = w_scalar.get_dtype(space)
+ w_arr = W_NDimArray.from_shape(space, [], dtype)  # empty shape [] -> 0-dim array
+ w_arr.set_scalar_value(w_scalar)
+ return w_arr
+
def convert_to_array(space, w_obj):
from pypy.module.micronumpy.ctors import array
diff --git a/pypy/module/micronumpy/ctors.py b/pypy/module/micronumpy/ctors.py
--- a/pypy/module/micronumpy/ctors.py
+++ b/pypy/module/micronumpy/ctors.py
@@ -4,7 +4,8 @@
from rpython.rlib.rstring import strip_spaces
from rpython.rtyper.lltypesystem import lltype, rffi
from pypy.module.micronumpy import descriptor, loop
-from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
+from pypy.module.micronumpy.base import (
+ W_NDimArray, convert_to_array, W_NumpyObject)
from pypy.module.micronumpy.converters import shape_converter
@@ -100,6 +101,17 @@
return w_arr
+def numpify(space, w_object):
+ """Convert the object to a W_NumpyObject: objects that already are one are returned unchanged; anything else goes through array(), and a scalar result is unwrapped to its scalar value instead of being returned as an array"""
+ if isinstance(w_object, W_NumpyObject):
+ return w_object
+ w_res = array(space, w_object)
+ if w_res.is_scalar():  # presumably "0-dim array"; hand back the scalar itself rather than the array
+ return w_res.get_scalar_value()
+ else:
+ return w_res
+
+
def zeros(space, w_shape, w_dtype=None, w_order=None):
dtype = space.interp_w(descriptor.W_Dtype,
space.call_function(space.gettypefor(descriptor.W_Dtype), w_dtype))
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -7,6 +7,7 @@
from rpython.tool.sourcetools import func_with_new_name
from pypy.module.micronumpy import boxes, descriptor, loop, constants as NPY
from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
+from pypy.module.micronumpy.ctors import numpify
from pypy.module.micronumpy.strides import shape_agreement
@@ -17,6 +18,13 @@
def done_if_false(dtype, val):
return not dtype.itemtype.bool(val)
+def _get_dtype(space, w_npyobj):  # dtype of an operand produced by numpify() (a scalar box or an array)
+ if isinstance(w_npyobj, boxes.W_GenericBox):
+ return w_npyobj.get_dtype(space)  # boxes take `space`; the array method below does not
+ else:
+ assert isinstance(w_npyobj, W_NDimArray)  # non-box operands must be arrays; presumably also narrows the type for RPython
+ return w_npyobj.get_dtype()
+
class W_Ufunc(W_Root):
_immutable_fields_ = [
@@ -385,15 +393,10 @@
else:
[w_lhs, w_rhs] = args_w
w_out = None
- if (isinstance(w_lhs, boxes.W_GenericBox) and
- isinstance(w_rhs, boxes.W_GenericBox)):
- w_ldtype = w_lhs.get_dtype(space)
- w_rdtype = w_rhs.get_dtype(space)
- else:
- w_lhs = convert_to_array(space, w_lhs)
- w_rhs = convert_to_array(space, w_rhs)
- w_ldtype = w_lhs.get_dtype()
- w_rdtype = w_rhs.get_dtype()
+ w_lhs = numpify(space, w_lhs)
+ w_rhs = numpify(space, w_rhs)
+ w_ldtype = _get_dtype(space, w_lhs)
+ w_rdtype = _get_dtype(space, w_rhs)
if w_ldtype.is_str() and w_rdtype.is_str() and \
self.comparison_func:
pass
@@ -456,7 +459,11 @@
else:
out = arr
return out
+ if isinstance(w_lhs, boxes.W_GenericBox):
+ w_lhs = W_NDimArray.from_scalar(space, w_lhs)
assert isinstance(w_lhs, W_NDimArray)
+ if isinstance(w_rhs, boxes.W_GenericBox):
+ w_rhs = W_NDimArray.from_scalar(space, w_rhs)
assert isinstance(w_rhs, W_NDimArray)
new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
More information about the pypy-commit
mailing list