[pypy-commit] pypy default: Unroll some functions in numpy correctly.

alex_gaynor noreply at buildbot.pypy.org
Wed Mar 28 19:33:05 CEST 2012


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: 
Changeset: r54043:cf91e948ab75
Date: 2012-03-28 13:32 -0400
http://bitbucket.org/pypy/pypy/changeset/cf91e948ab75/

Log:	Unroll some functions in numpy correctly.

diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -3,9 +3,11 @@
 from pypy.interpreter.gateway import interp2app, unwrap_spec, NoneNotWrapped
 from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty
 from pypy.module.micronumpy import interp_boxes, interp_dtype, support, loop
+from pypy.rlib import jit
 from pypy.rlib.rarithmetic import LONG_BIT
 from pypy.tool.sourcetools import func_with_new_name
 
+
 class W_Ufunc(Wrappable):
     _attrs_ = ["name", "promote_to_float", "promote_bools", "identity"]
     _immutable_fields_ = ["promote_to_float", "promote_bools", "name"]
@@ -179,7 +181,7 @@
                 elif out.shape != shape:
                     raise operationerrfmt(space.w_ValueError,
                         'output parameter shape mismatch, expecting [%s]' +
-                        ' , got [%s]', 
+                        ' , got [%s]',
                         ",".join([str(x) for x in shape]),
                         ",".join([str(x) for x in out.shape]),
                         )
@@ -204,7 +206,7 @@
         else:
             arr = ReduceArray(self.func, self.name, self.identity, obj, dtype)
             val = loop.compute(arr)
-        return val 
+        return val
 
     def do_axis_reduce(self, obj, dtype, axis, result):
         from pypy.module.micronumpy.interp_numarray import AxisReduce
@@ -253,7 +255,7 @@
         if isinstance(w_obj, Scalar):
             arr = self.func(calc_dtype, w_obj.value.convert_to(calc_dtype))
             if isinstance(out,Scalar):
-                out.value=arr
+                out.value = arr
             elif isinstance(out, BaseArray):
                 out.fill(space, arr)
             else:
@@ -265,7 +267,7 @@
             if not broadcast_shape or broadcast_shape != out.shape:
                 raise operationerrfmt(space.w_ValueError,
                     'output parameter shape mismatch, could not broadcast [%s]' +
-                    ' to [%s]', 
+                    ' to [%s]',
                     ",".join([str(x) for x in w_obj.shape]),
                     ",".join([str(x) for x in out.shape]),
                     )
@@ -292,10 +294,11 @@
         self.func = func
         self.comparison_func = comparison_func
 
+    @jit.unroll_safe
     def call(self, space, args_w):
         from pypy.module.micronumpy.interp_numarray import (Call2,
             convert_to_array, Scalar, shape_agreement, BaseArray)
-        if len(args_w)>2:
+        if len(args_w) > 2:
             [w_lhs, w_rhs, w_out] = args_w
         else:
             [w_lhs, w_rhs] = args_w
@@ -326,7 +329,7 @@
                 w_rhs.value.convert_to(calc_dtype)
             )
             if isinstance(out,Scalar):
-                out.value=arr
+                out.value = arr
             elif isinstance(out, BaseArray):
                 out.fill(space, arr)
             else:
@@ -337,7 +340,7 @@
         if out and out.shape != shape_agreement(space, new_shape, out.shape):
             raise operationerrfmt(space.w_ValueError,
                 'output parameter shape mismatch, could not broadcast [%s]' +
-                ' to [%s]', 
+                ' to [%s]',
                 ",".join([str(x) for x in new_shape]),
                 ",".join([str(x) for x in out.shape]),
                 )
@@ -347,7 +350,6 @@
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         if out:
-            #out.add_invalidates(w_res) #causes a recursion loop
             w_res.get_concrete()
         return w_res
 
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,6 +1,7 @@
 from pypy.rlib import jit
 from pypy.interpreter.error import OperationError
 
+@jit.look_inside_iff(lambda chunks: jit.isconstant(len(chunks)))
 def enumerate_chunks(chunks):
     result = []
     i = -1
@@ -85,9 +86,9 @@
         space.isinstance_w(w_item_or_slice, space.w_slice)):
         raise OperationError(space.w_IndexError,
                              space.wrap('unsupported iterator index'))
-            
+
     start, stop, step, lngth = space.decode_index4(w_item_or_slice, size)
-        
+
     coords = [0] * len(shape)
     i = start
     if order == 'C':
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -1,3 +1,7 @@
+from pypy.rlib import jit
+
+
+@jit.look_inside_iff(lambda s: jit.isconstant(len(s)))
 def product(s):
     i = 1
     for x in s:


More information about the pypy-commit mailing list