[pypy-commit] pypy numppy-flatitter: add jit_merge_points, add tests for them, tests fail
mattip
noreply at buildbot.pypy.org
Fri Jan 27 14:20:16 CET 2012
Author: mattip
Branch: numppy-flatitter
Changeset: r51852:ef26cbf01db4
Date: 2012-01-27 15:19 +0200
http://bitbucket.org/pypy/pypy/changeset/ef26cbf01db4/
Log: add jit_merge_points, add tests for them, tests fail
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -59,6 +59,20 @@
name='numpy_filterset',
)
+# JIT driver for the element-copy loop in the flat iterator's descr_getitem.
+flat_get_driver = jit.JitDriver(
+ greens=['shapelen', 'base'],
+ reds=['step', 'ri', 'basei', 'res'],
+ name='numpy_flatget',
+)
+
+# JIT driver for the element-assignment loop in the flat iterator's descr_setitem.
+flat_set_driver = jit.JitDriver(
+ greens=['shapelen', 'base'],
+ reds=['step', 'ai', 'lngth', 'arr', 'basei'],
+ name='numpy_flatset',
+)
+
def _find_shape_and_elems(space, w_iterable):
shape = [space.len_w(w_iterable)]
batch = space.listview(w_iterable)
@@ -1482,26 +1496,34 @@
space.wrap(self.index))
return space.newtuple([space.wrap(c) for c in coords])
def descr_getitem(self, space, w_idx):
if not (space.isinstance_w(w_idx, space.w_int) or
space.isinstance_w(w_idx, space.w_slice)):
raise OperationError(space.w_IndexError,
space.wrap('unsupported iterator index'))
- start, stop, step, lngth = space.decode_index4(w_idx, self.base.size)
+ base = self.base
+ start, stop, step, lngth = space.decode_index4(w_idx, base.size)
# setslice would have been better, but flat[u:v] for arbitrary
# shapes of array a cannot be represented as a[x1:x2, y1:y2]
- basei = ViewIterator(self.base.start, self.base.strides,
- self.base.backstrides,self.base.shape)
- shapelen = len(self.base.shape)
+ basei = ViewIterator(base.start, base.strides,
+ base.backstrides, base.shape)
+ shapelen = len(base.shape)
basei = basei.next_skip_x(shapelen, start)
+ # NOTE(review): empty or one-element *slices* also take this scalar path; numpy returns an array there -- confirm.
if lngth <2:
- return self.base.getitem(basei.offset)
+ return base.getitem(basei.offset)
ri = ArrayIterator(lngth)
- res = W_NDimArray(lngth, [lngth], self.base.dtype,
- self.base.order)
+ res = W_NDimArray(lngth, [lngth], base.dtype,
+ base.order)
while not ri.done():
- # TODO: add a jit_merge_point?
- w_val = self.base.getitem(basei.offset)
+ flat_get_driver.jit_merge_point(shapelen=shapelen,
+ base=base,
+ basei=basei,
+ step=step,
+ res=res,
+ ri=ri,
+ )
+ w_val = base.getitem(basei.offset)
res.setitem(ri.offset,w_val)
basei = basei.next_skip_x(shapelen, step)
ri = ri.next(shapelen)
@@ -1512,20 +1534,32 @@
space.isinstance_w(w_idx, space.w_slice)):
raise OperationError(space.w_IndexError,
space.wrap('unsupported iterator index'))
- start, stop, step, lngth = space.decode_index4(w_idx, self.base.size)
+ base = self.base
+ start, stop, step, lngth = space.decode_index4(w_idx, base.size)
arr = convert_to_array(space, w_value)
+ if arr.size == 0:
+ # nothing to assign; also keeps `ai % arr.size` below from dividing by zero
+ return
ai = 0
- basei = ViewIterator(self.base.start, self.base.strides,
- self.base.backstrides,self.base.shape)
- shapelen = len(self.base.shape)
+ basei = ViewIterator(base.start, base.strides,
+ base.backstrides, base.shape)
+ shapelen = len(base.shape)
basei = basei.next_skip_x(shapelen, start)
- for i in range(lngth):
+ while lngth > 0:
+ flat_set_driver.jit_merge_point(shapelen=shapelen,
+ basei=basei,
+ base=base,
+ step=step,
+ arr=arr,
+ ai=ai,
+ lngth=lngth,
+ )
- # TODO: add jit_merge_point?
- v = arr.getitem(ai).convert_to(self.base.dtype)
- self.base.setitem(basei.offset, v)
+ v = arr.getitem(ai).convert_to(base.dtype)
+ base.setitem(basei.offset, v)
# need to repeat input values until all assignments are done
ai = (ai + 1) % arr.size
basei = basei.next_skip_x(shapelen, step)
+ lngth -= 1
def create_sig(self):
return signature.FlatSignature(self.base.dtype)
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -389,13 +389,28 @@
return '''
a = |30|
b = flat(a)
- b -> 6
+ b -> 4: -> 6
'''
def test_flat_getitem(self):
result = self.run("flat_getitem")
- assert result == 6.0
- self.check_trace_count(0)
+ assert result == 10.0  # flat(a)[4:] starts at a[4], so its item 6 is a[10]
+ #self.check_trace_count(1)  # TODO: re-enable once the JIT trace checks pass
+ #self.check_simple_loop({})
+
+ def define_flat_setitem():
+ return '''
+ a = |30|
+ b = flat(a)
+ b -> 4: = a->:26
+ a -> 5
+ '''
+
+ def test_flat_setitem(self):
+ result = self.run("flat_setitem")
+ assert result == 1.0  # flat(a)[4:] = a[:26] assigns a[5] from a[1], which is never overwritten
+ #self.check_trace_count(1)  # TODO: re-enable once the JIT trace checks pass
+ #self.check_simple_loop({})
class TestNumpyOld(LLJitMixin):
def setup_class(cls):
More information about the pypy-commit
mailing list