[pypy-commit] pypy default: optimize call2 when lhs or rhs is a scalar

bdkearns noreply at buildbot.pypy.org
Thu Dec 4 05:38:58 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r74805:378bfef11c09
Date: 2014-12-03 23:30 -0500
http://bitbucket.org/pypy/pypy/changeset/378bfef11c09/

Log:	optimize call2 when lhs or rhs is a scalar

diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -22,6 +22,9 @@
     def get_shape(self):
         return self.shape
 
+    def get_size(self):
+        return self.base().get_size()
+
     def create_iter(self, shape=None, backward_broadcast=False):
         assert isinstance(self.base(), W_NDimArray)
         return self.base().create_iter()
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -42,23 +42,38 @@
 
     # TODO handle __array_priorities__ and maybe flip the order
 
+    if w_lhs.get_size() == 1:
+        w_left = w_lhs.get_scalar_value().convert_to(space, calc_dtype)
+        left_iter = left_state = None
+    else:
+        w_left = None
+        left_iter, left_state = w_lhs.create_iter(shape)
+        left_iter.track_index = False
+
+    if w_rhs.get_size() == 1:
+        w_right = w_rhs.get_scalar_value().convert_to(space, calc_dtype)
+        right_iter = right_state = None
+    else:
+        w_right = None
+        right_iter, right_state = w_rhs.create_iter(shape)
+        right_iter.track_index = False
+
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype,
                                      w_instance=lhs_for_subtype)
-    left_iter, left_state = w_lhs.create_iter(shape)
-    right_iter, right_state = w_rhs.create_iter(shape)
     out_iter, out_state = out.create_iter(shape)
-    left_iter.track_index = right_iter.track_index = False
     shapelen = len(shape)
     while not out_iter.done(out_state):
         call2_driver.jit_merge_point(shapelen=shapelen, func=func,
                                      calc_dtype=calc_dtype, res_dtype=res_dtype)
-        w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
-        w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+        if left_iter:
+            w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
+            left_state = left_iter.next(left_state)
+        if right_iter:
+            w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+            right_state = right_iter.next(right_state)
         out_iter.setitem(out_state, func(calc_dtype, w_left, w_right).convert_to(
             space, res_dtype))
-        left_state = left_iter.next(left_state)
-        right_state = right_iter.next(right_state)
         out_state = out_iter.next(out_state)
     return out
 
@@ -68,11 +83,12 @@
     reds='auto')
 
 def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
+    obj_iter, obj_state = w_obj.create_iter(shape)
+    obj_iter.track_index = False
+
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
-    obj_iter, obj_state = w_obj.create_iter(shape)
     out_iter, out_state = out.create_iter(shape)
-    obj_iter.track_index = False
     shapelen = len(shape)
     while not out_iter.done(out_state):
         call1_driver.jit_merge_point(shapelen=shapelen, func=func,
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -102,14 +102,13 @@
         assert result == 3 + 3
         self.check_trace_count(1)
         self.check_simple_loop({
-            'arraylen_gc': 1,
             'float_add': 1,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'int_add': 4,
+            'int_add': 3,
             'int_ge': 1,
             'jump': 1,
-            'raw_load': 2,
+            'raw_load': 1,
             'raw_store': 1,
         })
 
@@ -124,21 +123,18 @@
         assert result == 3 ** 2
         self.check_trace_count(1)
         self.check_simple_loop({
-            'arraylen_gc': 1,
             'call': 1,
-            'float_add': 1,
-            'float_eq': 3,
+            'float_eq': 2,
             'float_mul': 2,
-            'float_ne': 1,
             'getarrayitem_raw': 1,     # read the errno
-            'guard_false': 4,
+            'guard_false': 2,
             'guard_not_invalidated': 1,
             'guard_true': 2,
-            'int_add': 4,
+            'int_add': 3,
             'int_ge': 1,
             'int_is_true': 1,
             'jump': 1,
-            'raw_load': 2,
+            'raw_load': 1,
             'raw_store': 1,
             'setarrayitem_raw': 1,     # write the errno
         })
@@ -157,14 +153,13 @@
         self.check_trace_count(2)  # extra one for the astype
         del get_stats().loops[0]   # we don't care about it
         self.check_simple_loop({
-            'arraylen_gc': 1,
             'call': 1,
             'guard_false': 1,
             'guard_not_invalidated': 1,
-            'int_add': 4,
+            'int_add': 3,
             'int_ge': 1,
             'jump': 1,
-            'raw_load': 2,
+            'raw_load': 1,
             'raw_store': 1,
         })
 


More information about the pypy-commit mailing list