[pypy-commit] pypy default: Unroll some functions in numpy correctly.
alex_gaynor
noreply at buildbot.pypy.org
Wed Mar 28 19:33:05 CEST 2012
Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch:
Changeset: r54043:cf91e948ab75
Date: 2012-03-28 13:32 -0400
http://bitbucket.org/pypy/pypy/changeset/cf91e948ab75/
Log: Unroll some functions in numpy correctly.
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -3,9 +3,11 @@
from pypy.interpreter.gateway import interp2app, unwrap_spec, NoneNotWrapped
from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty
from pypy.module.micronumpy import interp_boxes, interp_dtype, support, loop
+from pypy.rlib import jit
from pypy.rlib.rarithmetic import LONG_BIT
from pypy.tool.sourcetools import func_with_new_name
+
class W_Ufunc(Wrappable):
_attrs_ = ["name", "promote_to_float", "promote_bools", "identity"]
_immutable_fields_ = ["promote_to_float", "promote_bools", "name"]
@@ -179,7 +181,7 @@
elif out.shape != shape:
raise operationerrfmt(space.w_ValueError,
'output parameter shape mismatch, expecting [%s]' +
- ' , got [%s]',
+ ' , got [%s]',
",".join([str(x) for x in shape]),
",".join([str(x) for x in out.shape]),
)
@@ -204,7 +206,7 @@
else:
arr = ReduceArray(self.func, self.name, self.identity, obj, dtype)
val = loop.compute(arr)
- return val
+ return val
def do_axis_reduce(self, obj, dtype, axis, result):
from pypy.module.micronumpy.interp_numarray import AxisReduce
@@ -253,7 +255,7 @@
if isinstance(w_obj, Scalar):
arr = self.func(calc_dtype, w_obj.value.convert_to(calc_dtype))
if isinstance(out,Scalar):
- out.value=arr
+ out.value = arr
elif isinstance(out, BaseArray):
out.fill(space, arr)
else:
@@ -265,7 +267,7 @@
if not broadcast_shape or broadcast_shape != out.shape:
raise operationerrfmt(space.w_ValueError,
'output parameter shape mismatch, could not broadcast [%s]' +
- ' to [%s]',
+ ' to [%s]',
",".join([str(x) for x in w_obj.shape]),
",".join([str(x) for x in out.shape]),
)
@@ -292,10 +294,11 @@
self.func = func
self.comparison_func = comparison_func
+ @jit.unroll_safe
def call(self, space, args_w):
from pypy.module.micronumpy.interp_numarray import (Call2,
convert_to_array, Scalar, shape_agreement, BaseArray)
- if len(args_w)>2:
+ if len(args_w) > 2:
[w_lhs, w_rhs, w_out] = args_w
else:
[w_lhs, w_rhs] = args_w
@@ -326,7 +329,7 @@
w_rhs.value.convert_to(calc_dtype)
)
if isinstance(out,Scalar):
- out.value=arr
+ out.value = arr
elif isinstance(out, BaseArray):
out.fill(space, arr)
else:
@@ -337,7 +340,7 @@
if out and out.shape != shape_agreement(space, new_shape, out.shape):
raise operationerrfmt(space.w_ValueError,
'output parameter shape mismatch, could not broadcast [%s]' +
- ' to [%s]',
+ ' to [%s]',
",".join([str(x) for x in new_shape]),
",".join([str(x) for x in out.shape]),
)
@@ -347,7 +350,6 @@
w_lhs.add_invalidates(w_res)
w_rhs.add_invalidates(w_res)
if out:
- #out.add_invalidates(w_res) #causes a recursion loop
w_res.get_concrete()
return w_res
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,6 +1,7 @@
from pypy.rlib import jit
from pypy.interpreter.error import OperationError
+ at jit.look_inside_iff(lambda chunks: jit.isconstant(len(chunks)))
def enumerate_chunks(chunks):
result = []
i = -1
@@ -85,9 +86,9 @@
space.isinstance_w(w_item_or_slice, space.w_slice)):
raise OperationError(space.w_IndexError,
space.wrap('unsupported iterator index'))
-
+
start, stop, step, lngth = space.decode_index4(w_item_or_slice, size)
-
+
coords = [0] * len(shape)
i = start
if order == 'C':
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -1,3 +1,7 @@
+from pypy.rlib import jit
+
+
+ at jit.look_inside_iff(lambda s: jit.isconstant(len(s)))
def product(s):
i = 1
for x in s:
More information about the pypy-commit mailing list