[pypy-commit] pypy default: optimize call2 when lhs or rhs is a scalar
bdkearns
noreply at buildbot.pypy.org
Thu Dec 4 05:38:58 CET 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch:
Changeset: r74805:378bfef11c09
Date: 2014-12-03 23:30 -0500
http://bitbucket.org/pypy/pypy/changeset/378bfef11c09/
Log: optimize call2 when lhs or rhs is a scalar
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -22,6 +22,9 @@
def get_shape(self):
return self.shape
+ def get_size(self):
+ return self.base().get_size()
+
def create_iter(self, shape=None, backward_broadcast=False):
assert isinstance(self.base(), W_NDimArray)
return self.base().create_iter()
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -42,23 +42,38 @@
# TODO handle __array_priorities__ and maybe flip the order
+ if w_lhs.get_size() == 1:
+ w_left = w_lhs.get_scalar_value().convert_to(space, calc_dtype)
+ left_iter = left_state = None
+ else:
+ w_left = None
+ left_iter, left_state = w_lhs.create_iter(shape)
+ left_iter.track_index = False
+
+ if w_rhs.get_size() == 1:
+ w_right = w_rhs.get_scalar_value().convert_to(space, calc_dtype)
+ right_iter = right_state = None
+ else:
+ w_right = None
+ right_iter, right_state = w_rhs.create_iter(shape)
+ right_iter.track_index = False
+
if out is None:
out = W_NDimArray.from_shape(space, shape, res_dtype,
w_instance=lhs_for_subtype)
- left_iter, left_state = w_lhs.create_iter(shape)
- right_iter, right_state = w_rhs.create_iter(shape)
out_iter, out_state = out.create_iter(shape)
- left_iter.track_index = right_iter.track_index = False
shapelen = len(shape)
while not out_iter.done(out_state):
call2_driver.jit_merge_point(shapelen=shapelen, func=func,
calc_dtype=calc_dtype, res_dtype=res_dtype)
- w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
- w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+ if left_iter:
+ w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
+ left_state = left_iter.next(left_state)
+ if right_iter:
+ w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+ right_state = right_iter.next(right_state)
out_iter.setitem(out_state, func(calc_dtype, w_left, w_right).convert_to(
space, res_dtype))
- left_state = left_iter.next(left_state)
- right_state = right_iter.next(right_state)
out_state = out_iter.next(out_state)
return out
@@ -68,11 +83,12 @@
reds='auto')
def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
+ obj_iter, obj_state = w_obj.create_iter(shape)
+ obj_iter.track_index = False
+
if out is None:
out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
- obj_iter, obj_state = w_obj.create_iter(shape)
out_iter, out_state = out.create_iter(shape)
- obj_iter.track_index = False
shapelen = len(shape)
while not out_iter.done(out_state):
call1_driver.jit_merge_point(shapelen=shapelen, func=func,
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -102,14 +102,13 @@
assert result == 3 + 3
self.check_trace_count(1)
self.check_simple_loop({
- 'arraylen_gc': 1,
'float_add': 1,
'guard_false': 1,
'guard_not_invalidated': 1,
- 'int_add': 4,
+ 'int_add': 3,
'int_ge': 1,
'jump': 1,
- 'raw_load': 2,
+ 'raw_load': 1,
'raw_store': 1,
})
@@ -124,21 +123,18 @@
assert result == 3 ** 2
self.check_trace_count(1)
self.check_simple_loop({
- 'arraylen_gc': 1,
'call': 1,
- 'float_add': 1,
- 'float_eq': 3,
+ 'float_eq': 2,
'float_mul': 2,
- 'float_ne': 1,
'getarrayitem_raw': 1, # read the errno
- 'guard_false': 4,
+ 'guard_false': 2,
'guard_not_invalidated': 1,
'guard_true': 2,
- 'int_add': 4,
+ 'int_add': 3,
'int_ge': 1,
'int_is_true': 1,
'jump': 1,
- 'raw_load': 2,
+ 'raw_load': 1,
'raw_store': 1,
'setarrayitem_raw': 1, # write the errno
})
@@ -157,14 +153,13 @@
self.check_trace_count(2) # extra one for the astype
del get_stats().loops[0] # we don't care about it
self.check_simple_loop({
- 'arraylen_gc': 1,
'call': 1,
'guard_false': 1,
'guard_not_invalidated': 1,
- 'int_add': 4,
+ 'int_add': 3,
'int_ge': 1,
'jump': 1,
- 'raw_load': 2,
+ 'raw_load': 1,
'raw_store': 1,
})
More information about the pypy-commit mailing list