[pypy-commit] pypy numppy-flatitter: add jit_merge_points, add tests for them (tests currently fail)

mattip noreply at buildbot.pypy.org
Fri Jan 27 14:20:16 CET 2012


Author: mattip
Branch: numppy-flatitter
Changeset: r51852:ef26cbf01db4
Date: 2012-01-27 15:19 +0200
http://bitbucket.org/pypy/pypy/changeset/ef26cbf01db4/

Log:	add jit_merge_points, add tests for them (tests currently fail)

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -59,6 +59,18 @@
     name='numpy_filterset',
 )
 
+flat_get_driver = jit.JitDriver(
+    greens=['shapelen', 'base'],
+    reds=['step', 'ri', 'basei', 'res'],
+    name='numpy_flatget',
+)
+
+flat_set_driver = jit.JitDriver(
+    greens=['shapelen', 'base'],
+    reds=['step', 'ai', 'lngth', 'arr', 'basei'],
+    name='numpy_flatset',
+)
+
 def _find_shape_and_elems(space, w_iterable):
     shape = [space.len_w(w_iterable)]
     batch = space.listview(w_iterable)
@@ -1482,26 +1494,34 @@
                             space.wrap(self.index))
         return space.newtuple([space.wrap(c) for c in coords])
 
+    @jit.unroll_safe
     def descr_getitem(self, space, w_idx):
         if not (space.isinstance_w(w_idx, space.w_int) or
             space.isinstance_w(w_idx, space.w_slice)):
             raise OperationError(space.w_IndexError,
                                  space.wrap('unsupported iterator index'))
-        start, stop, step, lngth = space.decode_index4(w_idx, self.base.size)
+        base = self.base
+        start, stop, step, lngth = space.decode_index4(w_idx, base.size)
         # setslice would have been better, but flat[u:v] for arbitrary
         # shapes of array a cannot be represented as a[x1:x2, y1:y2]
-        basei = ViewIterator(self.base.start, self.base.strides,
-                               self.base.backstrides,self.base.shape)
-        shapelen = len(self.base.shape)
+        basei = ViewIterator(base.start, base.strides,
+                               base.backstrides,base.shape)
+        shapelen = len(base.shape)
         basei = basei.next_skip_x(shapelen, start)
         if lngth <2:
-            return self.base.getitem(basei.offset)
+            return base.getitem(basei.offset)
         ri = ArrayIterator(lngth)
-        res = W_NDimArray(lngth, [lngth], self.base.dtype,
-                                    self.base.order)
+        res = W_NDimArray(lngth, [lngth], base.dtype,
+                                    base.order)
         while not ri.done():
-            # TODO: add a jit_merge_point?
-            w_val = self.base.getitem(basei.offset)
+            flat_get_driver.jit_merge_point(shapelen=shapelen,
+                                             base=base,
+                                             basei=basei,
+                                             step=step,
+                                             res=res,
+                                             ri=ri,
+                                            ) 
+            w_val = base.getitem(basei.offset)
             res.setitem(ri.offset,w_val)
             basei = basei.next_skip_x(shapelen, step)
             ri = ri.next(shapelen)
@@ -1512,20 +1532,30 @@
             space.isinstance_w(w_idx, space.w_slice)):
             raise OperationError(space.w_IndexError,
                                  space.wrap('unsupported iterator index'))
-        start, stop, step, lngth = space.decode_index4(w_idx, self.base.size)
+        base = self.base
+        start, stop, step, lngth = space.decode_index4(w_idx, base.size)
         arr = convert_to_array(space, w_value)
         ai = 0
-        basei = ViewIterator(self.base.start, self.base.strides,
-                               self.base.backstrides,self.base.shape)
-        shapelen = len(self.base.shape)
+        basei = ViewIterator(base.start, base.strides,
+                               base.backstrides,base.shape)
+        shapelen = len(base.shape)
         basei = basei.next_skip_x(shapelen, start)
-        for i in range(lngth):
+        while lngth > 0:
+            flat_set_driver.jit_merge_point(shapelen=shapelen,
+                                             basei=basei,
+                                             base=base,
+                                             step=step,
+                                             arr=arr,
+                                             ai=ai,
+                                             lngth=lngth,
+                                            ) 
             # TODO: add jit_merge_point?
-            v = arr.getitem(ai).convert_to(self.base.dtype)
-            self.base.setitem(basei.offset, v)
+            v = arr.getitem(ai).convert_to(base.dtype)
+            base.setitem(basei.offset, v)
             # need to repeat input values until all assignments are done
             ai = (ai + 1) % arr.size
             basei = basei.next_skip_x(shapelen, step)
+            lngth -= 1
 
     def create_sig(self):
         return signature.FlatSignature(self.base.dtype)
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -389,13 +389,28 @@
         return '''
         a = |30|
         b = flat(a)
-        b -> 6
+        b -> 4: -> 6
         '''
 
     def test_flat_getitem(self):
         result = self.run("flat_getitem")
-        assert result == 6.0
-        self.check_trace_count(0)
+        assert result == 10.0
+        #self.check_trace_count(1)
+        #self.check_simple_loop({})
+
+    def define_flat_setitem():
+        return '''
+        a = |30|
+        b = flat(a)
+        b -> 4: = a->:26
+        a -> 5
+        '''
+
+    def test_flat_setitem(self):
+        result = self.run("flat_setitem")
+        assert result == 1.0
+        #self.check_trace_count(1)
+        #self.check_simple_loop({})
 
 class TestNumpyOld(LLJitMixin):
     def setup_class(cls):


More information about the pypy-commit mailing list