[pypy-commit] pypy vecopt-merge-iterator-sharing: (plan_rich, ronan) first working version that generates all five possible call2 combinations that share the iterators
plan_rich
noreply at buildbot.pypy.org
Fri Aug 14 18:50:29 CEST 2015
Author: Richard Plangger <rich at pasra.at>
Branch: vecopt-merge-iterator-sharing
Changeset: r78988:690ba1eaa6a8
Date: 2015-08-14 18:50 +0200
http://bitbucket.org/pypy/pypy/changeset/690ba1eaa6a8/
Log: (plan_rich, ronan) first working version that generates all five
possible call2 combinations that share the iterators
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -88,7 +88,6 @@
return self.iterator.same_shape(other.iterator)
return False
-
class ArrayIter(object):
_immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
'strides[*]', 'backstrides[*]', 'factors[*]',
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -2,6 +2,7 @@
operations. This is the place to look for all the computations that iterate
over all the array elements.
"""
+import py
from pypy.interpreter.error import OperationError
from rpython.rlib import jit
from rpython.rlib.rstring import StringBuilder
@@ -13,11 +14,6 @@
from pypy.interpreter.argument import Arguments
-call2_driver = jit.JitDriver(
- name='numpy_call2',
- greens=['shapelen','state_count', 'left_index', 'right_index', 'left', 'right', 'func', 'calc_dtype', 'res_dtype'],
- reds='auto', vectorize=True)
-
def call2(space, shape, func, calc_dtype, w_lhs, w_rhs, out):
if w_lhs.get_size() == 1:
w_left = w_lhs.get_scalar_value().convert_to(space, calc_dtype)
@@ -40,68 +36,102 @@
res_dtype = out.get_dtype()
states = [out_state,left_state,right_state]
- out_index = 0
left_index = 1
right_index = 2
# left == right == out
# left == right
# left == out
# right == out
+ params = (space, shapelen, func, calc_dtype, res_dtype, out,
+ w_left, w_right, left_iter, right_iter, out_iter,
+ left_state, right_state, out_state)
if not right_iter:
+ # rhs is a scalar
del states[2]
else:
+ # rhs is NOT a scalar
if out_state.same(right_state):
# (1) out and right are the same -> remove right
right_index = 0
del states[2]
+ #
if not left_iter:
+ # lhs is a scalar
del states[1]
if right_index == 2:
right_index = 1
+ return call2_advance_out_right(*params)
else:
+ # lhs is NOT a scalar
if out_state.same(left_state):
# (2) out and left are the same -> remove left
left_index = 0
del states[1]
if right_index == 2:
right_index = 1
+ return call2_advance_out_right(*params)
else:
if len(states) == 3: # did not enter (1)
if right_iter and right_state.same(left_state):
right_index = 1
del states[2]
+ return call2_advance_out_left_eq_right(*params)
+ else:
+ # worst case
+ return call2_advance_out_left_right(*params)
+ else:
+ return call2_advance_out_left(*params)
+
state_count = len(states)
+ if state_count == 1:
+ return call2_advance_out(*params)
+
+ assert 0, "logical problem with the selection of the call 2 case"
+
+def generate_call2_cases(name, left_state, right_state):
+ call2_driver = jit.JitDriver(name='numpy_call2_' + name,
+ greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+ reds='auto', vectorize=True)
#
- while not out_iter.done(states[0]):
- call2_driver.jit_merge_point(shapelen=shapelen,
- func=func,
- left=left_iter is None,
- right=right_iter is None,
- state_count=state_count,
- left_index=left_index,
- right_index=right_index,
- calc_dtype=calc_dtype,
- res_dtype=res_dtype)
- if left_iter:
- left_state = states[left_index]
- w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
- if right_iter:
- right_state = states[right_index]
- w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
- w_out = func(calc_dtype, w_left, w_right)
- out_iter.setitem(states[0], w_out.convert_to(space, res_dtype))
- #
- for i,state in enumerate(states):
- states[i] = state.iterator.next(state)
+ advance_left_state = left_state == "left_state"
+ advance_right_state = right_state == "right_state"
+ code = """
+ def method(space, shapelen, func, calc_dtype, res_dtype, out,
+ w_left, w_right, left_iter, right_iter, out_iter,
+ left_state, right_state, out_state):
+ while not out_iter.done(out_state):
+ call2_driver.jit_merge_point(shapelen=shapelen, func=func,
+ calc_dtype=calc_dtype, res_dtype=res_dtype)
+ if left_iter:
+ w_left = left_iter.getitem({left_state}).convert_to(space, calc_dtype)
+ if right_iter:
+ w_right = right_iter.getitem({right_state}).convert_to(space, calc_dtype)
+ w_out = func(calc_dtype, w_left, w_right)
+ out_iter.setitem(out_state, w_out.convert_to(space, res_dtype))
+ out_state = out_iter.next(out_state)
+ if advance_left_state and left_iter:
+ left_state = left_iter.next(left_state)
+ if advance_right_state and right_iter:
+ right_state = right_iter.next(right_state)
+ #
+ # if not set to None, the values will be loop carried
+ # (for the var,var case), forcing the vectorization to unpack
+ # the vector registers at the end of the loop
+ if left_iter:
+ w_left = None
+ if right_iter:
+ w_right = None
+ return out
+ """
+ exec(py.code.Source(code.format(left_state=left_state,right_state=right_state)).compile(), locals())
+ method.__name__ = "call2_" + name
+ return method
- # if not set to None, the values will be loop carried
- # (for the var,var case), forcing the vectorization to unpack
- # the vector registers at the end of the loop
- if left_iter:
- w_left = None
- if right_iter:
- w_right = None
- return out
+call2_advance_out = generate_call2_cases("inc_out", "out_state", "out_state")
+call2_advance_out_left = generate_call2_cases("inc_out_left", "left_state", "out_state")
+call2_advance_out_right = generate_call2_cases("inc_out_right", "out_state", "right_state")
+call2_advance_out_left_eq_right = generate_call2_cases("inc_out_left_eq_right", "left_state", "left_state")
+call2_advance_out_left_right = generate_call2_cases("inc_out_left_right", "left_state", "right_state")
call1_driver = jit.JitDriver(
name='numpy_call1',
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -911,8 +911,10 @@
def test_multidim_slice(self):
result = self.run('multidim_slice')
assert result == 12
- self.check_trace_count(2)
- self.check_vectorized(1,0) # TODO?
+ self.check_trace_count(3)
+ # ::2 creates a view object -> needs an inner loop
+ # that iterates over continuous chunks of the matrix
+ self.check_vectorized(1,1)
# NOT WORKING
More information about the pypy-commit
mailing list