[pypy-commit] pypy vecopt-merge-iterator-sharing: (plan_rich, ronan) first working version that generates all five possible call2 combinations that shares the iterators

plan_rich noreply at buildbot.pypy.org
Fri Aug 14 18:50:29 CEST 2015


Author: Richard Plangger <rich at pasra.at>
Branch: vecopt-merge-iterator-sharing
Changeset: r78988:690ba1eaa6a8
Date: 2015-08-14 18:50 +0200
http://bitbucket.org/pypy/pypy/changeset/690ba1eaa6a8/

Log:	(plan_rich, ronan) first working version that generates all five
	possible call2 combinations that share the iterators

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -88,7 +88,6 @@
             return self.iterator.same_shape(other.iterator)
         return False
 
-
 class ArrayIter(object):
     _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]', 'factors[*]',
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -2,6 +2,7 @@
 operations. This is the place to look for all the computations that iterate
 over all the array elements.
 """
+import py
 from pypy.interpreter.error import OperationError
 from rpython.rlib import jit
 from rpython.rlib.rstring import StringBuilder
@@ -13,11 +14,6 @@
 from pypy.interpreter.argument import Arguments
 
 
-call2_driver = jit.JitDriver(
-    name='numpy_call2',
-    greens=['shapelen','state_count', 'left_index', 'right_index', 'left', 'right', 'func', 'calc_dtype', 'res_dtype'],
-    reds='auto', vectorize=True)
-
 def call2(space, shape, func, calc_dtype, w_lhs, w_rhs, out):
     if w_lhs.get_size() == 1:
         w_left = w_lhs.get_scalar_value().convert_to(space, calc_dtype)
@@ -40,68 +36,102 @@
     res_dtype = out.get_dtype()
 
     states = [out_state,left_state,right_state]
-    out_index = 0
     left_index = 1
     right_index = 2
     # left == right == out
     # left == right
     # left == out
     # right == out
+    params = (space, shapelen, func, calc_dtype, res_dtype, out,
+              w_left, w_right, left_iter, right_iter, out_iter,
+              left_state, right_state, out_state)
     if not right_iter:
+        # rhs is a scalar
         del states[2]
     else:
+        # rhs is NOT a scalar
         if out_state.same(right_state):
             # (1) out and right are the same -> remove right
             right_index = 0
             del states[2]
+    #
     if not left_iter:
+        # lhs is a scalar
         del states[1]
         if right_index == 2:
             right_index = 1
+            return call2_advance_out_right(*params)
     else:
+        # lhs is NOT a scalar
         if out_state.same(left_state):
             # (2) out and left are the same -> remove left
             left_index = 0
             del states[1]
             if right_index == 2:
                 right_index = 1
+            return call2_advance_out_right(*params)
         else:
             if len(states) == 3: # did not enter (1)
                 if right_iter and right_state.same(left_state):
                     right_index = 1
                     del states[2]
+                    return call2_advance_out_left_eq_right(*params)
+                else:
+                    # worst case
+                    return call2_advance_out_left_right(*params)
+            else:
+                return call2_advance_out_left(*params)
+
     state_count = len(states)
+    if state_count == 1:
+        return call2_advance_out(*params)
+
+    assert 0, "logical problem with the selection of the call 2 case"
+
+def generate_call2_cases(name, left_state, right_state):
+    call2_driver = jit.JitDriver(name='numpy_call2_' + name,
+        greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+        reds='auto', vectorize=True)
     #
-    while not out_iter.done(states[0]):
-        call2_driver.jit_merge_point(shapelen=shapelen,
-                                     func=func,
-                                     left=left_iter is None,
-                                     right=right_iter is None,
-                                     state_count=state_count,
-                                     left_index=left_index,
-                                     right_index=right_index,
-                                     calc_dtype=calc_dtype,
-                                     res_dtype=res_dtype)
-        if left_iter:
-            left_state = states[left_index]
-            w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
-        if right_iter:
-            right_state = states[right_index]
-            w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
-        w_out = func(calc_dtype, w_left, w_right)
-        out_iter.setitem(states[0], w_out.convert_to(space, res_dtype))
-        #
-        for i,state in enumerate(states):
-            states[i] = state.iterator.next(state)
+    advance_left_state = left_state == "left_state"
+    advance_right_state = right_state == "right_state"
+    code = """
+    def method(space, shapelen, func, calc_dtype, res_dtype, out,
+               w_left, w_right, left_iter, right_iter, out_iter,
+               left_state, right_state, out_state):
+        while not out_iter.done(out_state):
+            call2_driver.jit_merge_point(shapelen=shapelen, func=func,
+                    calc_dtype=calc_dtype, res_dtype=res_dtype)
+            if left_iter:
+                w_left = left_iter.getitem({left_state}).convert_to(space, calc_dtype)
+            if right_iter:
+                w_right = right_iter.getitem({right_state}).convert_to(space, calc_dtype)
+            w_out = func(calc_dtype, w_left, w_right)
+            out_iter.setitem(out_state, w_out.convert_to(space, res_dtype))
+            out_state = out_iter.next(out_state)
+            if advance_left_state and left_iter:
+                left_state = left_iter.next(left_state)
+            if advance_right_state and right_iter:
+                right_state = right_iter.next(right_state)
+            #
+            # if not set to None, the values will be loop carried
+            # (for the var,var case), forcing the vectorization to unpack
+            # the vector registers at the end of the loop
+            if left_iter:
+                w_left = None
+            if right_iter:
+                w_right = None
+        return out
+    """
+    exec(py.code.Source(code.format(left_state=left_state,right_state=right_state)).compile(), locals())
+    method.__name__ = "call2_" + name
+    return method
 
-        # if not set to None, the values will be loop carried
-        # (for the var,var case), forcing the vectorization to unpack
-        # the vector registers at the end of the loop
-        if left_iter:
-            w_left = None
-        if right_iter:
-            w_right = None
-    return out
+call2_advance_out = generate_call2_cases("inc_out", "out_state", "out_state")
+call2_advance_out_left = generate_call2_cases("inc_out_left", "left_state", "out_state")
+call2_advance_out_right = generate_call2_cases("inc_out_right", "out_state", "right_state")
+call2_advance_out_left_eq_right = generate_call2_cases("inc_out_left_eq_right", "left_state", "left_state")
+call2_advance_out_left_right = generate_call2_cases("inc_out_left_right", "left_state", "right_state")
 
 call1_driver = jit.JitDriver(
     name='numpy_call1',
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -911,8 +911,10 @@
     def test_multidim_slice(self):
         result = self.run('multidim_slice')
         assert result == 12
-        self.check_trace_count(2)
-        self.check_vectorized(1,0) # TODO?
+        self.check_trace_count(3)
+        # ::2 creates a view object -> needs an inner loop
+        # that iterates over contiguous chunks of the matrix
+        self.check_vectorized(1,1) 
 
     # NOT WORKING
 


More information about the pypy-commit mailing list