[pypy-commit] pypy default: (mattip, fijal) merge numpypy-axisops, this adds axis=x argument to reduce

fijal noreply at buildbot.pypy.org
Sat Jan 14 15:10:49 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: 
Changeset: r51306:5c19878c1ef9
Date: 2012-01-14 16:10 +0200
http://bitbucket.org/pypy/pypy/changeset/5c19878c1ef9/

Log:	(mattip, fijal) merge numpypy-axisops, this adds axis=x argument to
	reduce functions.

diff --git a/pypy/jit/metainterp/history.py b/pypy/jit/metainterp/history.py
--- a/pypy/jit/metainterp/history.py
+++ b/pypy/jit/metainterp/history.py
@@ -1003,16 +1003,16 @@
         return insns
 
     def check_simple_loop(self, expected=None, **check):
-        # Usefull in the simplest case when we have only one trace ending with
-        # a jump back to itself and possibly a few bridges ending with finnish.
-        # Only the operations within the loop formed by that single jump will
-        # be counted.
+        """ Useful in the simplest case when we have only one trace ending with
+        a jump back to itself and possibly a few bridges.
+        Only the operations within the loop formed by that single jump will
+        be counted.
+        """
         loops = self.get_all_loops()
         assert len(loops) == 1
         loop = loops[0]
         jumpop = loop.operations[-1]
         assert jumpop.getopnum() == rop.JUMP
-        assert self.check_resops(jump=1)
         labels = [op for op in loop.operations if op.getopnum() == rop.LABEL]
         targets = [op._descr_wref() for op in labels]
         assert None not in targets # TargetToken was freed, give up
diff --git a/pypy/module/micronumpy/app_numpy.py b/pypy/module/micronumpy/app_numpy.py
--- a/pypy/module/micronumpy/app_numpy.py
+++ b/pypy/module/micronumpy/app_numpy.py
@@ -19,25 +19,49 @@
         a[i][i] = 1
     return a
 
-def mean(a):
+def mean(a, axis=None):
     if not hasattr(a, "mean"):
         a = _numpypy.array(a)
-    return a.mean()
+    return a.mean(axis)
 
-def sum(a):
+def sum(a,axis=None):
+    '''sum(a, axis=None)
+    Sum of array elements over a given axis.
+    
+    Parameters
+    ----------
+    a : array_like
+        Elements to sum.
+    axis : integer, optional
+        Axis over which the sum is taken. By default `axis` is None,
+        and all elements are summed.
+    
+    Returns
+    -------
+    sum_along_axis : ndarray
+        An array with the same shape as `a`, with the specified
+        axis removed.   If `a` is a 0-d array, or if `axis` is None, a scalar
+        is returned.  Unlike numpy.sum, this function takes no `out`
+        argument.
+    
+    See Also
+    --------
+    ndarray.sum : Equivalent method.
+    '''
+    # TODO: add to doc (once it's implemented): cumsum : Cumulative sum of array elements.
     if not hasattr(a, "sum"):
         a = _numpypy.array(a)
-    return a.sum()
+    return a.sum(axis)
 
-def min(a):
+def min(a, axis=None):
     if not hasattr(a, "min"):
         a = _numpypy.array(a)
-    return a.min()
+    return a.min(axis)
 
-def max(a):
+def max(a, axis=None):
     if not hasattr(a, "max"):
         a = _numpypy.array(a)
-    return a.max()
+    return a.max(axis)
 
 def arange(start, stop=None, step=1, dtype=None):
     '''arange([start], stop[, step], dtype=None)
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -372,13 +372,17 @@
 
     def execute(self, interp):
         if self.name in SINGLE_ARG_FUNCTIONS:
-            if len(self.args) != 1:
+            if len(self.args) != 1 and self.name != 'sum':
                 raise ArgumentMismatch
             arr = self.args[0].execute(interp)
             if not isinstance(arr, BaseArray):
                 raise ArgumentNotAnArray
             if self.name == "sum":
-                w_res = arr.descr_sum(interp.space)
+                if len(self.args)>1:
+                    w_res = arr.descr_sum(interp.space,
+                                          self.args[1].execute(interp))
+                else:
+                    w_res = arr.descr_sum(interp.space)
             elif self.name == "prod":
                 w_res = arr.descr_prod(interp.space)
             elif self.name == "max":
@@ -416,7 +420,7 @@
     ('\]', 'array_right'),
     ('(->)|[\+\-\*\/]', 'operator'),
     ('=', 'assign'),
-    (',', 'coma'),
+    (',', 'comma'),
     ('\|', 'pipe'),
     ('\(', 'paren_left'),
     ('\)', 'paren_right'),
@@ -504,7 +508,7 @@
         return SliceConstant(start, stop, step)
 
 
-    def parse_expression(self, tokens):
+    def parse_expression(self, tokens, accept_comma=False):
         stack = []
         while tokens.remaining():
             token = tokens.pop()
@@ -524,9 +528,13 @@
                 stack.append(RangeConstant(tokens.pop().v))
                 end = tokens.pop()
                 assert end.name == 'pipe'
+            elif accept_comma and token.name == 'comma':
+                continue
             else:
                 tokens.push()
                 break
+        if accept_comma:
+            return stack
         stack.reverse()
         lhs = stack.pop()
         while stack:
@@ -540,7 +548,7 @@
         args = []
         tokens.pop() # lparen
         while tokens.get(0).name != 'paren_right':
-            args.append(self.parse_expression(tokens))
+            args += self.parse_expression(tokens, accept_comma=True)
         return FunctionCall(name, args)
 
     def parse_array_const(self, tokens):
@@ -556,7 +564,7 @@
             token = tokens.pop()
             if token.name == 'array_right':
                 return elems
-            assert token.name == 'coma'
+            assert token.name == 'comma'
 
     def parse_statement(self, tokens):
         if (tokens.get(0).name == 'identifier' and
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -1,19 +1,20 @@
 
 from pypy.rlib import jit
 from pypy.rlib.objectmodel import instantiate
-from pypy.module.micronumpy.strides import calculate_broadcast_strides
+from pypy.module.micronumpy.strides import calculate_broadcast_strides,\
+     calculate_slice_strides
 
-# Iterators for arrays
-# --------------------
-# all those iterators with the exception of BroadcastIterator iterate over the
-# entire array in C order (the last index changes the fastest). This will
-# yield all elements. Views iterate over indices and look towards strides and
-# backstrides to find the correct position. Notably the offset between
-# x[..., i + 1] and x[..., i] will be strides[-1]. Offset between
-# x[..., k + 1, 0] and x[..., k, i_max] will be backstrides[-2] etc.
+class BaseTransform(object):
+    pass
 
-# BroadcastIterator works like that, but for indexes that don't change source
-# in the original array, strides[i] == backstrides[i] == 0
+class ViewTransform(BaseTransform):
+    def __init__(self, chunks):
+        # 4-tuple specifying slicing
+        self.chunks = chunks
+
+class BroadcastTransform(BaseTransform):
+    def __init__(self, res_shape):
+        self.res_shape = res_shape
 
 class BaseIterator(object):
     def next(self, shapelen):
@@ -22,6 +23,15 @@
     def done(self):
         raise NotImplementedError
 
+    def apply_transformations(self, arr, transformations):
+        v = self
+        for transform in transformations:
+            v = v.transform(arr, transform)
+        return v
+
+    def transform(self, arr, t):
+        raise NotImplementedError
+
 class ArrayIterator(BaseIterator):
     def __init__(self, size):
         self.offset = 0
@@ -36,6 +46,10 @@
     def done(self):
         return self.offset >= self.size
 
+    def transform(self, arr, t):
+        return ViewIterator(arr.start, arr.strides, arr.backstrides,
+                            arr.shape).transform(arr, t)
+
 class OneDimIterator(BaseIterator):
     def __init__(self, start, step, stop):
         self.offset = start
@@ -52,26 +66,29 @@
     def done(self):
         return self.offset == self.size
 
-def view_iter_from_arr(arr):
-    return ViewIterator(arr.start, arr.strides, arr.backstrides, arr.shape)
-
 class ViewIterator(BaseIterator):
-    def __init__(self, start, strides, backstrides, shape, res_shape=None):
+    def __init__(self, start, strides, backstrides, shape):
         self.offset  = start
         self._done   = False
-        if res_shape is not None and res_shape != shape:
-            r = calculate_broadcast_strides(strides, backstrides,
-                                            shape, res_shape)
-            self.strides, self.backstrides = r
-            self.res_shape = res_shape
-        else:
-            self.strides = strides
-            self.backstrides = backstrides
-            self.res_shape = shape
+        self.strides = strides
+        self.backstrides = backstrides
+        self.res_shape = shape
         self.indices = [0] * len(self.res_shape)
 
+    def transform(self, arr, t):
+        if isinstance(t, BroadcastTransform):
+            r = calculate_broadcast_strides(self.strides, self.backstrides,
+                                            self.res_shape, t.res_shape)
+            return ViewIterator(self.offset, r[0], r[1], t.res_shape)
+        elif isinstance(t, ViewTransform):
+            r = calculate_slice_strides(self.res_shape, self.offset,
+                                        self.strides,
+                                        self.backstrides, t.chunks)
+            return ViewIterator(r[1], r[2], r[3], r[0])
+
     @jit.unroll_safe
     def next(self, shapelen):
+        shapelen = jit.promote(len(self.res_shape))
         offset = self.offset
         indices = [0] * shapelen
         for i in range(shapelen):
@@ -96,6 +113,13 @@
         res._done = done
         return res
 
+    def apply_transformations(self, arr, transformations):
+        v = BaseIterator.apply_transformations(self, arr, transformations)
+        if len(arr.shape) == 1:
+            return OneDimIterator(self.offset, self.strides[0],
+                                  self.res_shape[0])
+        return v
+
     def done(self):
         return self._done
 
@@ -103,11 +127,57 @@
     def next(self, shapelen):
         return self
 
+    def transform(self, arr, t):
+        pass
+
+class AxisIterator(BaseIterator):
+    def __init__(self, start, dim, shape, strides, backstrides):
+        self.res_shape = shape[:]
+        self.strides = strides[:dim] + [0] + strides[dim:]
+        self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
+        self.first_line = True
+        self.indices = [0] * len(shape)
+        self._done = False
+        self.offset = start
+        self.dim = dim
+
+    @jit.unroll_safe
+    def next(self, shapelen):
+        offset = self.offset
+        first_line = self.first_line
+        indices = [0] * shapelen
+        for i in range(shapelen):
+            indices[i] = self.indices[i]
+        done = False
+        for i in range(shapelen - 1, -1, -1):
+            if indices[i] < self.res_shape[i] - 1:
+                if i == self.dim:
+                    first_line = False
+                indices[i] += 1
+                offset += self.strides[i]
+                break
+            else:
+                indices[i] = 0
+                offset -= self.backstrides[i]
+        else:
+            done = True
+        res = instantiate(AxisIterator)
+        res.offset = offset
+        res.indices = indices
+        res.strides = self.strides
+        res.backstrides = self.backstrides
+        res.res_shape = self.res_shape
+        res._done = done
+        res.first_line = first_line
+        res.dim = self.dim
+        return res        
+
+    def done(self):
+        return self._done
+
 # ------ other iterators that are not part of the computation frame ----------
-
-class AxisIterator(object):
-    """ This object will return offsets of each start of the last stride
-    """
+    
+class SkipLastAxisIterator(object):
     def __init__(self, arr):
         self.arr = arr
         self.indices = [0] * (len(arr.shape) - 1)
@@ -125,4 +195,3 @@
                 self.offset -= self.arr.backstrides[i]
         else:
             self.done = True
-        
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -8,8 +8,8 @@
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.rstring import StringBuilder
-from pypy.module.micronumpy.interp_iter import ArrayIterator,\
-     view_iter_from_arr, OneDimIterator, AxisIterator
+from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
+     SkipLastAxisIterator
 
 numpy_driver = jit.JitDriver(
     greens=['shapelen', 'sig'],
@@ -35,11 +35,12 @@
 slice_driver = jit.JitDriver(
     greens=['shapelen', 'sig'],
     virtualizables=['frame'],
-    reds=['self', 'frame', 'source', 'res_iter'],
+    reds=['self', 'frame', 'arr'],
     get_printable_location=signature.new_printable_location('slice'),
     name='numpy_slice',
 )
 
+
 def _find_shape_and_elems(space, w_iterable):
     shape = [space.len_w(w_iterable)]
     batch = space.listview(w_iterable)
@@ -286,13 +287,17 @@
     descr_rpow = _binop_right_impl("power")
     descr_rmod = _binop_right_impl("mod")
 
-    def _reduce_ufunc_impl(ufunc_name):
-        def impl(self, space):
-            return getattr(interp_ufuncs.get(space), ufunc_name).reduce(space, self, multidim=True)
+    def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False):
+        def impl(self, space, w_dim=None):
+            if space.is_w(w_dim, space.w_None):
+                w_dim = space.wrap(-1)
+            return getattr(interp_ufuncs.get(space), ufunc_name).reduce(space,
+                                        self, True, promote_to_largest, w_dim)
         return func_with_new_name(impl, "reduce_%s_impl" % ufunc_name)
 
     descr_sum = _reduce_ufunc_impl("add")
-    descr_prod = _reduce_ufunc_impl("multiply")
+    descr_sum_promote = _reduce_ufunc_impl("add", True)
+    descr_prod = _reduce_ufunc_impl("multiply", True)
     descr_max = _reduce_ufunc_impl("maximum")
     descr_min = _reduce_ufunc_impl("minimum")
 
@@ -377,7 +382,7 @@
         else:
             w_res = self.descr_mul(space, w_other)
             assert isinstance(w_res, BaseArray)
-            return w_res.descr_sum(space)
+            return w_res.descr_sum(space, space.wrap(-1))
 
     def get_concrete(self):
         raise NotImplementedError
@@ -565,16 +570,22 @@
             )
         return w_result
 
-    def descr_mean(self, space):
-        return space.div(self.descr_sum(space), space.wrap(self.size))
+    def descr_mean(self, space, w_dim=None):
+        if space.is_w(w_dim, space.w_None):
+            w_dim = space.wrap(-1)
+            w_denom = space.wrap(self.size)
+        else:
+            dim = space.int_w(w_dim)
+            w_denom = space.wrap(self.shape[dim])
+        return space.div(self.descr_sum_promote(space, w_dim), w_denom)
 
     def descr_var(self, space):
         # var = mean((values - mean(values)) ** 2)
-        w_res = self.descr_sub(space, self.descr_mean(space))
+        w_res = self.descr_sub(space, self.descr_mean(space, space.w_None))
         assert isinstance(w_res, BaseArray) 
         w_res = w_res.descr_pow(space, space.wrap(2))
         assert isinstance(w_res, BaseArray)
-        return w_res.descr_mean(space)
+        return w_res.descr_mean(space, space.w_None)
 
     def descr_std(self, space):
         # std(v) = sqrt(var(v))
@@ -613,11 +624,12 @@
     def getitem(self, item):
         raise NotImplementedError
 
-    def find_sig(self, res_shape=None):
+    def find_sig(self, res_shape=None, arr=None):
         """ find a correct signature for the array
         """
         res_shape = res_shape or self.shape
-        return signature.find_sig(self.create_sig(res_shape), self)
+        arr = arr or self
+        return signature.find_sig(self.create_sig(), arr)
 
     def descr_array_iface(self, space):
         if not self.shape:
@@ -671,7 +683,7 @@
     def copy(self, space):
         return Scalar(self.dtype, self.value)
 
-    def create_sig(self, res_shape):
+    def create_sig(self):
         return signature.ScalarSignature(self.dtype)
 
     def get_concrete_or_scalar(self):
@@ -689,7 +701,8 @@
         self.name = name
 
     def _del_sources(self):
-        # Function for deleting references to source arrays, to allow garbage-collecting them
+        # Function for deleting references to source arrays,
+        # to allow garbage-collecting them
         raise NotImplementedError
 
     def compute(self):
@@ -741,11 +754,11 @@
         self.size = size
         VirtualArray.__init__(self, 'slice', shape, child.find_dtype())
 
-    def create_sig(self, res_shape):
+    def create_sig(self):
         if self.forced_result is not None:
-            return self.forced_result.create_sig(res_shape)
+            return self.forced_result.create_sig()
         return signature.VirtualSliceSignature(
-            self.child.create_sig(res_shape))
+            self.child.create_sig())
 
     def force_if_needed(self):
         if self.forced_result is None:
@@ -755,6 +768,7 @@
     def _del_sources(self):
         self.child = None
 
+
 class Call1(VirtualArray):
     def __init__(self, ufunc, name, shape, res_dtype, values):
         VirtualArray.__init__(self, name, shape, res_dtype)
@@ -765,16 +779,17 @@
     def _del_sources(self):
         self.values = None
 
-    def create_sig(self, res_shape):
+    def create_sig(self):
         if self.forced_result is not None:
-            return self.forced_result.create_sig(res_shape)
-        return signature.Call1(self.ufunc, self.name,
-                               self.values.create_sig(res_shape))
+            return self.forced_result.create_sig()
+        return signature.Call1(self.ufunc, self.name, self.values.create_sig())
 
 class Call2(VirtualArray):
     """
     Intermediate class for performing binary operations.
     """
+    _immutable_fields_ = ['left', 'right']
+    
     def __init__(self, ufunc, name, shape, calc_dtype, res_dtype, left, right):
         VirtualArray.__init__(self, name, shape, res_dtype)
         self.ufunc = ufunc
@@ -789,12 +804,55 @@
         self.left = None
         self.right = None
 
-    def create_sig(self, res_shape):
+    def create_sig(self):
         if self.forced_result is not None:
-            return self.forced_result.create_sig(res_shape)
+            return self.forced_result.create_sig()
+        if self.shape != self.left.shape and self.shape != self.right.shape:
+            return signature.BroadcastBoth(self.ufunc, self.name,
+                                           self.calc_dtype,
+                                           self.left.create_sig(),
+                                           self.right.create_sig())
+        elif self.shape != self.left.shape:
+            return signature.BroadcastLeft(self.ufunc, self.name,
+                                           self.calc_dtype,
+                                           self.left.create_sig(),
+                                           self.right.create_sig())
+        elif self.shape != self.right.shape:
+            return signature.BroadcastRight(self.ufunc, self.name,
+                                            self.calc_dtype,
+                                            self.left.create_sig(),
+                                            self.right.create_sig())
         return signature.Call2(self.ufunc, self.name, self.calc_dtype,
-                               self.left.create_sig(res_shape),
-                               self.right.create_sig(res_shape))
+                               self.left.create_sig(), self.right.create_sig())
+
+class SliceArray(Call2):
+    def __init__(self, shape, dtype, left, right):
+        Call2.__init__(self, None, 'sliceloop', shape, dtype, dtype, left,
+                       right)
+    
+    def create_sig(self):
+        lsig = self.left.create_sig()
+        rsig = self.right.create_sig()
+        if self.shape != self.right.shape:
+            return signature.SliceloopBroadcastSignature(self.ufunc,
+                                                         self.name,
+                                                         self.calc_dtype,
+                                                         lsig, rsig)
+        return signature.SliceloopSignature(self.ufunc, self.name,
+                                            self.calc_dtype,
+                                            lsig, rsig)
+
+class AxisReduce(Call2):
+    """ NOTE: this is only used as a container, you should never
+    encounter such things in the wild. Remove this comment
+    when we'll make AxisReduce lazy
+    """
+    _immutable_fields_ = ['left', 'right']
+    
+    def __init__(self, ufunc, name, shape, dtype, left, right, dim):
+        Call2.__init__(self, ufunc, name, shape, dtype, dtype,
+                       left, right)
+        self.dim = dim
 
 class ConcreteArray(BaseArray):
     """ An array that have actual storage, whether owned or not
@@ -849,11 +907,6 @@
         self.strides = strides
         self.backstrides = backstrides
 
-    def array_sig(self, res_shape):
-        if res_shape is not None and self.shape != res_shape:
-            return signature.ViewSignature(self.dtype)
-        return signature.ArraySignature(self.dtype)
-
     def to_str(self, space, comma, builder, indent=' ', use_ellipsis=False):
         '''Modifies builder with a representation of the array/slice
         The items will be seperated by a comma if comma is 1
@@ -890,7 +943,7 @@
                     view.to_str(space, comma, builder, indent=indent + ' ',
                                                     use_ellipsis=use_ellipsis)
                 if i < self.shape[0] - 1:
-                    builder.append(ccomma +'\n' + indent + '...' + ncomma)
+                    builder.append(ccomma + '\n' + indent + '...' + ncomma)
                     i = self.shape[0] - 3
                 else:
                     i += 1
@@ -968,20 +1021,22 @@
             self.dtype is w_value.find_dtype()):
             self._fast_setslice(space, w_value)
         else:
-            self._sliceloop(w_value, res_shape)
+            arr = SliceArray(self.shape, self.dtype, self, w_value)
+            self._sliceloop(arr)
 
     def _fast_setslice(self, space, w_value):
         assert isinstance(w_value, ConcreteArray)
         itemsize = self.dtype.itemtype.get_element_size()
-        if len(self.shape) == 1:
+        shapelen = len(self.shape)
+        if shapelen == 1:
             rffi.c_memcpy(
                 rffi.ptradd(self.storage, self.start * itemsize),
                 rffi.ptradd(w_value.storage, w_value.start * itemsize),
                 self.size * itemsize
             )
         else:
-            dest = AxisIterator(self)
-            source = AxisIterator(w_value)
+            dest = SkipLastAxisIterator(self)
+            source = SkipLastAxisIterator(w_value)
             while not dest.done:
                 rffi.c_memcpy(
                     rffi.ptradd(self.storage, dest.offset * itemsize),
@@ -991,21 +1046,16 @@
                 source.next()
                 dest.next()
 
-    def _sliceloop(self, source, res_shape):
-        sig = source.find_sig(res_shape)
-        frame = sig.create_frame(source, res_shape)
-        res_iter = view_iter_from_arr(self)
-        shapelen = len(res_shape)
-        while not res_iter.done():
-            slice_driver.jit_merge_point(sig=sig,
-                                         frame=frame,
-                                         shapelen=shapelen,
-                                         self=self, source=source,
-                                         res_iter=res_iter)
-            self.setitem(res_iter.offset, sig.eval(frame, source).convert_to(
-                self.find_dtype()))
+    def _sliceloop(self, arr):
+        sig = arr.find_sig()
+        frame = sig.create_frame(arr)
+        shapelen = len(self.shape)
+        while not frame.done():
+            slice_driver.jit_merge_point(sig=sig, frame=frame, self=self,
+                                         arr=arr,
+                                         shapelen=shapelen)
+            sig.eval(frame, arr)
             frame.next(shapelen)
-            res_iter = res_iter.next(shapelen)
 
     def copy(self, space):
         array = W_NDimArray(self.size, self.shape[:], self.dtype, self.order)
@@ -1014,7 +1064,7 @@
 
 
 class ViewArray(ConcreteArray):
-    def create_sig(self, res_shape):
+    def create_sig(self):
         return signature.ViewSignature(self.dtype)
 
 
@@ -1078,8 +1128,8 @@
         self.shape = new_shape
         self.calc_strides(new_shape)
 
-    def create_sig(self, res_shape):
-        return self.array_sig(res_shape)
+    def create_sig(self):
+        return signature.ArraySignature(self.dtype)
 
     def __del__(self):
         lltype.free(self.storage, flavor='raw', track_allocation=False)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -3,20 +3,29 @@
 from pypy.interpreter.gateway import interp2app
 from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty
 from pypy.module.micronumpy import interp_boxes, interp_dtype
-from pypy.module.micronumpy.signature import ReduceSignature, ScalarSignature,\
-     find_sig, new_printable_location
+from pypy.module.micronumpy.signature import ReduceSignature,\
+     find_sig, new_printable_location, AxisReduceSignature, ScalarSignature
 from pypy.rlib import jit
 from pypy.rlib.rarithmetic import LONG_BIT
 from pypy.tool.sourcetools import func_with_new_name
 
 reduce_driver = jit.JitDriver(
-    greens = ['shapelen', "sig"],
-    virtualizables = ["frame"],
-    reds = ["frame", "self", "dtype", "value", "obj"],
+    greens=['shapelen', "sig"],
+    virtualizables=["frame"],
+    reds=["frame", "self", "dtype", "value", "obj"],
     get_printable_location=new_printable_location('reduce'),
     name='numpy_reduce',
 )
 
+axisreduce_driver = jit.JitDriver(
+    greens=['shapelen', 'sig'],
+    virtualizables=['frame'],
+    reds=['self','arr', 'identity', 'frame'],
+    name='numpy_axisreduce',
+    get_printable_location=new_printable_location('axisreduce'),
+)
+
+
 class W_Ufunc(Wrappable):
     _attrs_ = ["name", "promote_to_float", "promote_bools", "identity"]
     _immutable_fields_ = ["promote_to_float", "promote_bools", "name"]
@@ -49,18 +58,72 @@
             )
         return self.call(space, __args__.arguments_w)
 
-    def descr_reduce(self, space, w_obj):
-        return self.reduce(space, w_obj, multidim=False)
+    def descr_reduce(self, space, w_obj, w_dim=0):
+        """reduce(...)
+        reduce(a, axis=0)
 
-    def reduce(self, space, w_obj, multidim):
-        from pypy.module.micronumpy.interp_numarray import convert_to_array, Scalar
-        
+        Reduces `a`'s dimension by one, by applying ufunc along one axis.
+
+        Let :math:`a.shape = (N_0, ..., N_i, ..., N_{M-1})`.  Then
+        :math:`ufunc.reduce(a, axis=i)[k_0, ..,k_{i-1}, k_{i+1}, .., k_{M-1}]` =
+        the result of iterating `j` over :math:`range(N_i)`, cumulatively applying
+        ufunc to each :math:`a[k_0, ..,k_{i-1}, j, k_{i+1}, .., k_{M-1}]`.
+        For a one-dimensional array, reduce produces results equivalent to:
+        ::
+
+         r = op.identity # op = ufunc
+         for i in xrange(len(A)):
+           r = op(r, A[i])
+         return r
+
+        For example, add.reduce() is equivalent to sum().
+
+        Parameters
+        ----------
+        a : array_like
+            The array to act on.
+        axis : int, optional
+            The axis along which to apply the reduction.
+
+        Examples
+        --------
+        >>> np.multiply.reduce([2,3,5])
+        30
+
+        A multi-dimensional array example:
+
+        >>> X = np.arange(8).reshape((2,2,2))
+        >>> X
+        array([[[0, 1],
+                [2, 3]],
+               [[4, 5],
+                [6, 7]]])
+        >>> np.add.reduce(X, 0)
+        array([[ 4,  6],
+               [ 8, 10]])
+        >>> np.add.reduce(X) # confirm: default axis value is 0
+        array([[ 4,  6],
+               [ 8, 10]])
+        >>> np.add.reduce(X, 1)
+        array([[ 2,  4],
+               [10, 12]])
+        >>> np.add.reduce(X, 2)
+        array([[ 1,  5],
+               [ 9, 13]])
+        """
+        return self.reduce(space, w_obj, False, False, w_dim)
+
+    def reduce(self, space, w_obj, multidim, promote_to_largest, w_dim):
+        from pypy.module.micronumpy.interp_numarray import convert_to_array, \
+                                                           Scalar
         if self.argcount != 2:
             raise OperationError(space.w_ValueError, space.wrap("reduce only "
                 "supported for binary functions"))
-
+        dim = space.int_w(w_dim)
         assert isinstance(self, W_Ufunc2)
         obj = convert_to_array(space, w_obj)
+        if dim >= len(obj.shape):
+            raise OperationError(space.w_ValueError, space.wrap("axis(=%d) out of bounds" % dim))
         if isinstance(obj, Scalar):
             raise OperationError(space.w_TypeError, space.wrap("cannot reduce "
                 "on a scalar"))
@@ -68,26 +131,80 @@
         size = obj.size
         dtype = find_unaryop_result_dtype(
             space, obj.find_dtype(),
-            promote_to_largest=True
+            promote_to_float=self.promote_to_float,
+            promote_to_largest=promote_to_largest,
+            promote_bools=True
         )
         shapelen = len(obj.shape)
+        if self.identity is None and size == 0:
+            raise operationerrfmt(space.w_ValueError, "zero-size array to "
+                    "%s.reduce without identity", self.name)
+        if shapelen > 1 and dim >= 0:
+            res = self.do_axis_reduce(obj, dtype, dim)
+            return space.wrap(res)
+        scalarsig = ScalarSignature(dtype)
         sig = find_sig(ReduceSignature(self.func, self.name, dtype,
-                                       ScalarSignature(dtype),
-                                       obj.create_sig(obj.shape)), obj)
+                                       scalarsig,
+                                       obj.create_sig()), obj)
         frame = sig.create_frame(obj)
-        if shapelen > 1 and not multidim:
-            raise OperationError(space.w_NotImplementedError,
-                space.wrap("not implemented yet"))
         if self.identity is None:
-            if size == 0:
-                raise operationerrfmt(space.w_ValueError, "zero-size array to "
-                    "%s.reduce without identity", self.name)
             value = sig.eval(frame, obj).convert_to(dtype)
             frame.next(shapelen)
         else:
             value = self.identity.convert_to(dtype)
         return self.reduce_loop(shapelen, sig, frame, value, obj, dtype)
 
+    def do_axis_reduce(self, obj, dtype, dim):
+        from pypy.module.micronumpy.interp_numarray import AxisReduce,\
+             W_NDimArray
+        
+        shape = obj.shape[0:dim] + obj.shape[dim + 1:len(obj.shape)]
+        size = 1
+        for s in shape:
+            size *= s
+        result = W_NDimArray(size, shape, dtype)
+        rightsig = obj.create_sig()
+        # note - this is just a wrapper so signature can fetch
+        #        both left and right, nothing more, especially
+        #        this is not a true virtual array, because shapes
+        #        don't quite match
+        arr = AxisReduce(self.func, self.name, obj.shape, dtype,
+                         result, obj, dim)
+        scalarsig = ScalarSignature(dtype)
+        sig = find_sig(AxisReduceSignature(self.func, self.name, dtype,
+                                           scalarsig, rightsig), arr)
+        assert isinstance(sig, AxisReduceSignature)
+        frame = sig.create_frame(arr)
+        shapelen = len(obj.shape)
+        if self.identity is not None:
+            identity = self.identity.convert_to(dtype)
+        else:
+            identity = None
+        self.reduce_axis_loop(frame, sig, shapelen, arr, identity)
+        return result
+
+    def reduce_axis_loop(self, frame, sig, shapelen, arr, identity):
+        # note - we can be adventurous here, depending on the exact field
+        # layout. For now let's say we iterate the original way and
+        # simply follow the original iteration order
+        while not frame.done():
+            axisreduce_driver.jit_merge_point(frame=frame, self=self,
+                                              sig=sig,
+                                              identity=identity,
+                                              shapelen=shapelen, arr=arr)
+            iter = frame.get_final_iter()
+            v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
+            if iter.first_line:
+                if identity is not None:
+                    value = self.func(sig.calc_dtype, identity, v)
+                else:
+                    value = v
+            else:
+                cur = arr.left.getitem(iter.offset)
+                value = self.func(sig.calc_dtype, cur, v)
+            arr.left.setitem(iter.offset, value)
+            frame.next(shapelen)
+
     def reduce_loop(self, shapelen, sig, frame, value, obj, dtype):
         while not frame.done():
             reduce_driver.jit_merge_point(sig=sig,
@@ -95,10 +212,12 @@
                                           value=value, obj=obj, frame=frame,
                                           dtype=dtype)
             assert isinstance(sig, ReduceSignature)
-            value = sig.binfunc(dtype, value, sig.eval(frame, obj).convert_to(dtype))
+            value = sig.binfunc(dtype, value,
+                                sig.eval(frame, obj).convert_to(dtype))
             frame.next(shapelen)
         return value
 
+
 class W_Ufunc1(W_Ufunc):
     argcount = 1
 
@@ -183,6 +302,7 @@
     reduce = interp2app(W_Ufunc.descr_reduce),
 )
 
+
 def find_binop_result_dtype(space, dt1, dt2, promote_to_float=False,
     promote_bools=False):
     # dt1.num should be <= dt2.num
@@ -231,6 +351,7 @@
             dtypenum += 3
         return interp_dtype.get_dtype_cache(space).builtin_dtypes[dtypenum]
 
+
 def find_unaryop_result_dtype(space, dt, promote_to_float=False,
     promote_bools=False, promote_to_largest=False):
     if promote_bools and (dt.kind == interp_dtype.BOOLLTR):
@@ -255,6 +376,7 @@
             assert False
     return dt
 
+
 def find_dtype_for_scalar(space, w_obj, current_guess=None):
     bool_dtype = interp_dtype.get_dtype_cache(space).w_booldtype
     long_dtype = interp_dtype.get_dtype_cache(space).w_longdtype
@@ -348,7 +470,8 @@
 
         identity = extra_kwargs.get("identity")
         if identity is not None:
-            identity = interp_dtype.get_dtype_cache(space).w_longdtype.box(identity)
+            identity = \
+                 interp_dtype.get_dtype_cache(space).w_longdtype.box(identity)
         extra_kwargs["identity"] = identity
 
         func = ufunc_dtype_caller(space, ufunc_name, op_name, argcount,
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -1,10 +1,32 @@
 from pypy.rlib.objectmodel import r_dict, compute_identity_hash, compute_hash
 from pypy.rlib.rarithmetic import intmask
 from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
-     OneDimIterator, ConstantIterator
-from pypy.module.micronumpy.strides import calculate_slice_strides
+     ConstantIterator, AxisIterator, ViewTransform,\
+     BroadcastTransform
 from pypy.rlib.jit import hint, unroll_safe, promote
 
+""" Signature specifies both the numpy expression that has been constructed
+and the assembler to be compiled. This is a very important observation -
+Two expressions will be using the same assembler if and only if they are
+compiled to the same signature.
+
+This is also a very convenient tool for specializations. For example
+a + a and a + b (where a != b) will compile to different assembler because
+we specialize on the same array access.
+
+When evaluating, signatures will create iterators per signature node,
+potentially sharing some of them. Iterators depend also on the actual
+expression, they're not only dependent on the array itself. For example
+a + b where a is dim 2 and b is dim 1 would create a broadcasted iterator for
+the array b.
+
+Such iterator changes are called Transformations. An actual iterator would
+be a combination of an array and various transformations, like view, broadcast,
+dimension swapping etc.
+
+See interp_iter for transformations
+"""
+
 def new_printable_location(driver_name):
     def get_printable_location(shapelen, sig):
         return 'numpy ' + sig.debug_repr() + ' [%d dims,%s]' % (shapelen, driver_name)
@@ -33,7 +55,8 @@
         return sig
 
 class NumpyEvalFrame(object):
-    _virtualizable2_ = ['iterators[*]', 'final_iter', 'arraylist[*]']
+    _virtualizable2_ = ['iterators[*]', 'final_iter', 'arraylist[*]',
+                        'value', 'identity']
 
     @unroll_safe
     def __init__(self, iterators, arrays):
@@ -51,7 +74,7 @@
     def done(self):
         final_iter = promote(self.final_iter)
         if final_iter < 0:
-            return False
+            assert False
         return self.iterators[final_iter].done()
 
     @unroll_safe
@@ -59,6 +82,12 @@
         for i in range(len(self.iterators)):
             self.iterators[i] = self.iterators[i].next(shapelen)
 
+    def get_final_iter(self):
+        final_iter = promote(self.final_iter)
+        if final_iter < 0:
+            assert False
+        return self.iterators[final_iter]
+
 def _add_ptr_to_cache(ptr, cache):
     i = 0
     for p in cache:
@@ -70,6 +99,9 @@
         cache.append(ptr)
         return res
 
+def new_cache():
+    return r_dict(sigeq_no_numbering, sighash)
+
 class Signature(object):
     _attrs_ = ['iter_no', 'array_no']
     _immutable_fields_ = ['iter_no', 'array_no']
@@ -78,7 +110,7 @@
     iter_no = 0
 
     def invent_numbering(self):
-        cache = r_dict(sigeq_no_numbering, sighash)
+        cache = new_cache()
         allnumbers = []
         self._invent_numbering(cache, allnumbers)
 
@@ -95,13 +127,13 @@
             allnumbers.append(no)
         self.iter_no = no
 
-    def create_frame(self, arr, res_shape=None):
-        res_shape = res_shape or arr.shape
+    def create_frame(self, arr):
         iterlist = []
         arraylist = []
-        self._create_iter(iterlist, arraylist, arr, res_shape, [])
+        self._create_iter(iterlist, arraylist, arr, [])
         return NumpyEvalFrame(iterlist, arraylist)
 
+
 class ConcreteSignature(Signature):
     _immutable_fields_ = ['dtype']
 
@@ -120,16 +152,6 @@
     def hash(self):
         return compute_identity_hash(self.dtype)
 
-    def allocate_view_iter(self, arr, res_shape, chunklist):
-        r = arr.shape, arr.start, arr.strides, arr.backstrides
-        if chunklist:
-            for chunkelem in chunklist:
-                r = calculate_slice_strides(r[0], r[1], r[2], r[3], chunkelem)
-        shape, start, strides, backstrides = r
-        if len(res_shape) == 1:
-            return OneDimIterator(start, strides[0], res_shape[0])
-        return ViewIterator(start, strides, backstrides, shape, res_shape)
-
 class ArraySignature(ConcreteSignature):
     def debug_repr(self):
         return 'Array'
@@ -141,22 +163,21 @@
         # is not of a concrete class it means that we have a _forced_result,
         # otherwise the signature would not match
         assert isinstance(concr, ConcreteArray)
+        assert concr.dtype is self.dtype
         self.array_no = _add_ptr_to_cache(concr.storage, cache)
 
-    def _create_iter(self, iterlist, arraylist, arr, res_shape, chunklist):
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
         from pypy.module.micronumpy.interp_numarray import ConcreteArray
         concr = arr.get_concrete()
         assert isinstance(concr, ConcreteArray)
         storage = concr.storage
         if self.iter_no >= len(iterlist):
-            iterlist.append(self.allocate_iter(concr, res_shape, chunklist))
+            iterlist.append(self.allocate_iter(concr, transforms))
         if self.array_no >= len(arraylist):
             arraylist.append(storage)
 
-    def allocate_iter(self, arr, res_shape, chunklist):
-        if chunklist:
-            return self.allocate_view_iter(arr, res_shape, chunklist)
-        return ArrayIterator(arr.size)
+    def allocate_iter(self, arr, transforms):
+        return ArrayIterator(arr.size).apply_transformations(arr, transforms)
 
     def eval(self, frame, arr):
         iter = frame.iterators[self.iter_no]
@@ -169,7 +190,7 @@
     def _invent_array_numbering(self, arr, cache):
         pass
 
-    def _create_iter(self, iterlist, arraylist, arr, res_shape, chunklist):
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
         if self.iter_no >= len(iterlist):
             iter = ConstantIterator()
             iterlist.append(iter)
@@ -189,8 +210,9 @@
         allnumbers.append(no)
         self.iter_no = no
 
-    def allocate_iter(self, arr, res_shape, chunklist):
-        return self.allocate_view_iter(arr, res_shape, chunklist)
+    def allocate_iter(self, arr, transforms):
+        return ViewIterator(arr.start, arr.strides, arr.backstrides,
+                            arr.shape).apply_transformations(arr, transforms)
 
 class VirtualSliceSignature(Signature):
     def __init__(self, child):
@@ -201,6 +223,9 @@
         assert isinstance(arr, VirtualSlice)
         self.child._invent_array_numbering(arr.child, cache)
 
+    def _invent_numbering(self, cache, allnumbers):
+        self.child._invent_numbering(new_cache(), allnumbers)
+
     def hash(self):
         return intmask(self.child.hash() ^ 1234)
 
@@ -210,12 +235,11 @@
         assert isinstance(other, VirtualSliceSignature)
         return self.child.eq(other.child, compare_array_no)
 
-    def _create_iter(self, iterlist, arraylist, arr, res_shape, chunklist):
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
         from pypy.module.micronumpy.interp_numarray import VirtualSlice
         assert isinstance(arr, VirtualSlice)
-        chunklist.append(arr.chunks)
-        self.child._create_iter(iterlist, arraylist, arr.child, res_shape,
-                                chunklist)
+        transforms = transforms + [ViewTransform(arr.chunks)]
+        self.child._create_iter(iterlist, arraylist, arr.child, transforms)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import VirtualSlice
@@ -251,11 +275,10 @@
         assert isinstance(arr, Call1)
         self.child._invent_array_numbering(arr.values, cache)
 
-    def _create_iter(self, iterlist, arraylist, arr, res_shape, chunklist):
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
         from pypy.module.micronumpy.interp_numarray import Call1
         assert isinstance(arr, Call1)
-        self.child._create_iter(iterlist, arraylist, arr.values, res_shape,
-                                chunklist)
+        self.child._create_iter(iterlist, arraylist, arr.values, transforms)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call1
@@ -296,29 +319,68 @@
         self.left._invent_numbering(cache, allnumbers)
         self.right._invent_numbering(cache, allnumbers)
 
-    def _create_iter(self, iterlist, arraylist, arr, res_shape, chunklist):
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
         from pypy.module.micronumpy.interp_numarray import Call2
 
         assert isinstance(arr, Call2)
-        self.left._create_iter(iterlist, arraylist, arr.left, res_shape,
-                               chunklist)
-        self.right._create_iter(iterlist, arraylist, arr.right, res_shape,
-                                chunklist)
+        self.left._create_iter(iterlist, arraylist, arr.left, transforms)
+        self.right._create_iter(iterlist, arraylist, arr.right, transforms)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call2
         assert isinstance(arr, Call2)
         lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype)
         rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
+        
         return self.binfunc(self.calc_dtype, lhs, rhs)
 
     def debug_repr(self):
         return 'Call2(%s, %s, %s)' % (self.name, self.left.debug_repr(),
                                       self.right.debug_repr())
 
+class BroadcastLeft(Call2):
+    def _invent_numbering(self, cache, allnumbers):
+        self.left._invent_numbering(new_cache(), allnumbers)
+        self.right._invent_numbering(cache, allnumbers)
+    
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import Call2
+
+        assert isinstance(arr, Call2)
+        ltransforms = transforms + [BroadcastTransform(arr.shape)]
+        self.left._create_iter(iterlist, arraylist, arr.left, ltransforms)
+        self.right._create_iter(iterlist, arraylist, arr.right, transforms)
+
+class BroadcastRight(Call2):
+    def _invent_numbering(self, cache, allnumbers):
+        self.left._invent_numbering(cache, allnumbers)
+        self.right._invent_numbering(new_cache(), allnumbers)
+
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import Call2
+
+        assert isinstance(arr, Call2)
+        rtransforms = transforms + [BroadcastTransform(arr.shape)]
+        self.left._create_iter(iterlist, arraylist, arr.left, transforms)
+        self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)
+
+class BroadcastBoth(Call2):
+    def _invent_numbering(self, cache, allnumbers):
+        self.left._invent_numbering(new_cache(), allnumbers)
+        self.right._invent_numbering(new_cache(), allnumbers)
+
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import Call2
+
+        assert isinstance(arr, Call2)
+        rtransforms = transforms + [BroadcastTransform(arr.shape)]
+        ltransforms = transforms + [BroadcastTransform(arr.shape)]
+        self.left._create_iter(iterlist, arraylist, arr.left, ltransforms)
+        self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)
+
 class ReduceSignature(Call2):
-    def _create_iter(self, iterlist, arraylist, arr, res_shape, chunklist):
-        self.right._create_iter(iterlist, arraylist, arr, res_shape, chunklist)
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        self.right._create_iter(iterlist, arraylist, arr, transforms)
 
     def _invent_numbering(self, cache, allnumbers):
         self.right._invent_numbering(cache, allnumbers)
@@ -328,3 +390,63 @@
 
     def eval(self, frame, arr):
         return self.right.eval(frame, arr)
+
+    def debug_repr(self):
+        return 'ReduceSig(%s, %s)' % (self.name, self.right.debug_repr())
+
+class SliceloopSignature(Call2):
+    def eval(self, frame, arr):
+        from pypy.module.micronumpy.interp_numarray import Call2
+        
+        assert isinstance(arr, Call2)
+        ofs = frame.iterators[0].offset
+        arr.left.setitem(ofs, self.right.eval(frame, arr.right).convert_to(
+            self.calc_dtype))
+    
+    def debug_repr(self):
+        return 'SliceLoop(%s, %s, %s)' % (self.name, self.left.debug_repr(),
+                                          self.right.debug_repr())
+
+class SliceloopBroadcastSignature(SliceloopSignature):
+    def _invent_numbering(self, cache, allnumbers):
+        self.left._invent_numbering(new_cache(), allnumbers)
+        self.right._invent_numbering(cache, allnumbers)
+
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import SliceArray
+
+        assert isinstance(arr, SliceArray)
+        rtransforms = transforms + [BroadcastTransform(arr.shape)]
+        self.left._create_iter(iterlist, arraylist, arr.left, transforms)
+        self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)
+
+class AxisReduceSignature(Call2):
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import AxisReduce,\
+             ConcreteArray
+
+        assert isinstance(arr, AxisReduce)
+        left = arr.left
+        assert isinstance(left, ConcreteArray)
+        iterlist.append(AxisIterator(left.start, arr.dim, arr.shape,
+                                     left.strides, left.backstrides))
+        self.right._create_iter(iterlist, arraylist, arr.right, transforms)
+
+    def _invent_numbering(self, cache, allnumbers):
+        allnumbers.append(0)
+        self.right._invent_numbering(cache, allnumbers)
+
+    def _invent_array_numbering(self, arr, cache):
+        from pypy.module.micronumpy.interp_numarray import AxisReduce
+
+        assert isinstance(arr, AxisReduce)
+        self.right._invent_array_numbering(arr.right, cache)
+
+    def eval(self, frame, arr):
+        from pypy.module.micronumpy.interp_numarray import AxisReduce
+
+        assert isinstance(arr, AxisReduce)
+        return self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
+    
+    def debug_repr(self):
+        return 'AxisReduceSig(%s, %s)' % (self.name, self.right.debug_repr())
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -246,6 +246,10 @@
         c = b.copy()
         assert (c == b).all()
 
+        a = arange(15).reshape(5,3)
+        b = a.copy()
+        assert (b == a).all()
+
     def test_iterator_init(self):
         from _numpypy import array
         a = array(range(5))
@@ -720,10 +724,15 @@
         assert d[1] == 12
 
     def test_mean(self):
-        from _numpypy import array
+        from _numpypy import array, mean
         a = array(range(5))
         assert a.mean() == 2.0
         assert a[:4].mean() == 1.5
+        a = array(range(105)).reshape(3, 5, 7)
+        b = mean(a, axis=0)
+        b[0,0]==35.
+        assert (b == array(range(35, 70), dtype=float).reshape(5, 7)).all()
+        assert (mean(a, 2) == array(range(0, 15), dtype=float).reshape(3, 5) * 7 + 3).all()
 
     def test_sum(self):
         from _numpypy import array
@@ -734,6 +743,32 @@
         a = array([True] * 5, bool)
         assert a.sum() == 5
 
+        raises(TypeError, 'a.sum(2, 3)')
+
+    def test_reduce_nd(self):
+        from numpypy import arange, array, multiply
+        a = arange(15).reshape(5, 3)
+        assert a.sum() == 105
+        assert a.max() == 14
+        assert array([]).sum() == 0.0
+        raises(ValueError, 'array([]).max()')
+        assert (a.sum(0) == [30, 35, 40]).all()
+        assert (a.sum(1) == [3, 12, 21, 30, 39]).all()
+        assert (a.max(0) == [12, 13, 14]).all()
+        assert (a.max(1) == [2, 5, 8, 11, 14]).all()
+        assert ((a + a).max() == 28)
+        assert ((a + a).max(0) == [24, 26, 28]).all()
+        assert ((a + a).sum(1) == [6, 24, 42, 60, 78]).all()
+        assert (multiply.reduce(a) == array([0, 3640, 12320])).all()
+        a = array(range(105)).reshape(3, 5, 7)
+        assert (a[:, 1, :].sum(0) == [126, 129, 132, 135, 138, 141, 144]).all()
+        assert (a[:, 1, :].sum(1) == [70, 315, 560]).all()
+        raises (ValueError, 'a[:, 1, :].sum(2)')
+        assert ((a + a).T.sum(2).T == (a + a).sum(0)).all()
+        skip("Those are broken on reshape, fix!")
+        assert (a.reshape(1,-1).sum(0) == range(105)).all()
+        assert (a.reshape(1,-1).sum(1) == 5460)
+
     def test_identity(self):
         from _numpypy import identity, array
         from _numpypy import int32, float64, dtype
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -298,7 +298,7 @@
         for i in range(len(a)):
             assert b[i] == math.atan(a[i])
 
-        a  = array([float('nan')])
+        a = array([float('nan')])
         b = arctan(a)
         assert math.isnan(b[0])
 
@@ -336,9 +336,9 @@
         from _numpypy import sin, add
 
         raises(ValueError, sin.reduce, [1, 2, 3])
-        raises(TypeError, add.reduce, 1)
+        raises(ValueError, add.reduce, 1)
 
-    def test_reduce(self):
+    def test_reduce_1d(self):
         from _numpypy import add, maximum
 
         assert add.reduce([1, 2, 3]) == 6
@@ -346,6 +346,12 @@
         assert maximum.reduce([1, 2, 3]) == 3
         raises(ValueError, maximum.reduce, [])
 
+    def test_reduceND(self):
+        from numpypy import add, arange
+        a = arange(12).reshape(3, 4)
+        assert (add.reduce(a, 0) == [12, 15, 18, 21]).all()
+        assert (add.reduce(a, 1) == [6.0, 22.0, 38.0]).all()
+
     def test_comparisons(self):
         import operator
         from _numpypy import equal, not_equal, less, less_equal, greater, greater_equal
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -47,6 +47,8 @@
         def f(i):
             interp = InterpreterState(codes[i])
             interp.run(space)
+            if not len(interp.results):
+                raise Exception("need results")
             w_res = interp.results[-1]
             if isinstance(w_res, BaseArray):
                 concr = w_res.get_concrete_or_scalar()
@@ -115,6 +117,28 @@
                                 "int_add": 1, "int_ge": 1, "guard_false": 1,
                                 "jump": 1, 'arraylen_gc': 1})
 
+    def define_axissum():
+        return """
+        a = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+        b = sum(a,0)
+        b -> 1
+        """
+
+    def test_axissum(self):
+        result = self.run("axissum")
+        assert result == 30
+        # XXX note - the bridge here is fairly crucial and yet it's pretty
+        #            bogus. We need to improve the situation somehow.
+        self.check_simple_loop({'getinteriorfield_raw': 2,
+                                'setinteriorfield_raw': 1,
+                                'arraylen_gc': 1,
+                                'guard_true': 1,
+                                'int_lt': 1,
+                                'jump': 1,
+                                'float_add': 1,
+                                'int_add': 3,
+                                })
+
     def define_prod():
         return """
         a = |30|
@@ -193,9 +217,9 @@
         # This is the sum of the ops for both loops, however if you remove the
         # optimization then you end up with 2 float_adds, so we can still be
         # sure it was optimized correctly.
-        self.check_resops({'setinteriorfield_raw': 4, 'getfield_gc': 26,
+        self.check_resops({'setinteriorfield_raw': 4, 'getfield_gc': 22,
                            'getarrayitem_gc': 4, 'getarrayitem_gc_pure': 2,
-                           'getfield_gc_pure': 4,
+                           'getfield_gc_pure': 8,
                            'guard_class': 8, 'int_add': 8, 'float_mul': 2,
                            'jump': 2, 'int_ge': 4,
                            'getinteriorfield_raw': 4, 'float_add': 2,
@@ -212,7 +236,8 @@
     def test_ufunc(self):
         result = self.run("ufunc")
         assert result == -6
-        self.check_simple_loop({"getinteriorfield_raw": 2, "float_add": 1, "float_neg": 1,
+        self.check_simple_loop({"getinteriorfield_raw": 2, "float_add": 1,
+                                "float_neg": 1,
                                 "setinteriorfield_raw": 1, "int_add": 2,
                                 "int_ge": 1, "guard_false": 1, "jump": 1,
                                 'arraylen_gc': 1})
@@ -322,10 +347,9 @@
         result = self.run("setslice")
         assert result == 11.0
         self.check_trace_count(1)
-        self.check_simple_loop({'getinteriorfield_raw': 2, 'float_add' : 1,
-                                'setinteriorfield_raw': 1, 'int_add': 3,
-                                'int_lt': 1, 'guard_true': 1, 'jump': 1,
-                                'arraylen_gc': 3})
+        self.check_simple_loop({'getinteriorfield_raw': 2, 'float_add': 1,
+                                'setinteriorfield_raw': 1, 'int_add': 2,
+                                'int_eq': 1, 'guard_false': 1, 'jump': 1})
 
     def define_virtual_slice():
         return """
@@ -339,11 +363,12 @@
         result = self.run("virtual_slice")
         assert result == 4
         self.check_trace_count(1)
-        self.check_simple_loop({'getinteriorfield_raw': 2, 'float_add' : 1,
+        self.check_simple_loop({'getinteriorfield_raw': 2, 'float_add': 1,
                                 'setinteriorfield_raw': 1, 'int_add': 2,
                                 'int_ge': 1, 'guard_false': 1, 'jump': 1,
                                 'arraylen_gc': 1})
 
+
 class TestNumpyOld(LLJitMixin):
     def setup_class(cls):
         py.test.skip("old")
@@ -377,4 +402,3 @@
 
         result = self.meta_interp(f, [5], listops=True, backendopt=True)
         assert result == f(5)
-


More information about the pypy-commit mailing list