[pypy-svn] pypy numpy-exp: Implemented subtraction in numpy (thanks to brentp). Also refactored to metaprogram (this is why we use pypy).

alex_gaynor commits-noreply at bitbucket.org
Thu May 5 00:33:29 CEST 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-exp
Changeset: r43896:52cadacd06c8
Date: 2011-05-04 18:33 -0400
http://bitbucket.org/pypy/pypy/changeset/52cadacd06c8/

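Log:	Implemented subtraction in numpy (thanks to brentp). Also
	refactored the binary-operator methods (add, mul, sub) so they are
	generated by a shared factory function instead of being written
	out by hand (metaprogramming like this is why we use pypy).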

diff --git a/pypy/module/micronumpy/numarray.py b/pypy/module/micronumpy/numarray.py
--- a/pypy/module/micronumpy/numarray.py
+++ b/pypy/module/micronumpy/numarray.py
@@ -1,10 +1,12 @@
 from pypy.interpreter.baseobjspace import ObjSpace, W_Root, Wrappable
 from pypy.interpreter.error import operationerrfmt
+from pypy.interpreter.gateway import interp2app, unwrap_spec
 from pypy.interpreter.typedef import TypeDef
-from pypy.interpreter.gateway import interp2app, unwrap_spec
-from pypy.rpython.lltypesystem import lltype
 from pypy.rlib import jit
 from pypy.rlib.nonconst import NonConstant
+from pypy.rpython.lltypesystem import lltype
+from pypy.tool.sourcetools import func_with_new_name
+
 
 TP = lltype.Array(lltype.Float, hints={'nolength': True})
 
@@ -117,13 +119,18 @@
                 frame.pushvalue(val)
             elif opcode == 'a':
                 # Add.
+                a = frame.popvalue()
                 b = frame.popvalue()
+                frame.pushvalue(a + b)
+            elif opcode == 's':
+                # Subtract
                 a = frame.popvalue()
-                frame.pushvalue(a + b)
+                b = frame.popvalue()
+                frame.pushvalue(a - b)
             elif opcode == 'm':
                 # Multiply.
+                a = frame.popvalue()
                 b = frame.popvalue()
-                a = frame.popvalue()
                 frame.pushvalue(a * b)
             else:
                 raise NotImplementedError(
@@ -145,19 +152,21 @@
         # (we still have to compile new bytecode, but too bad)
         return compute(code)
 
-    def descr_add(self, space, w_other):
-        if isinstance(w_other, BaseArray):
-            return space.wrap(BinOp('a', self, w_other))
-        else:
-            return space.wrap(BinOp('a', self,
-                FloatWrapper(space.float_w(w_other))))
+    def _binop_impl(bytecode):
+        def impl(self, space, w_other):
+            if isinstance(w_other, BaseArray):
+                return space.wrap(BinOp(bytecode, self, w_other))
+            else:
+                return space.wrap(BinOp(
+                    bytecode,
+                    self,
+                    FloatWrapper(space.float_w(w_other))
+                ))
+        return func_with_new_name(impl, "binop_%s_impl" % bytecode)
 
-    def descr_mul(self, space, w_other):
-        if isinstance(w_other, BaseArray):
-            return space.wrap(BinOp('m', self, w_other))
-        else:
-            return space.wrap(BinOp('m', self,
-                FloatWrapper(space.float_w(w_other))))
+    descr_add = _binop_impl("a")
+    descr_mul = _binop_impl("m")
+    descr_sub = _binop_impl("s")
 
     def compile(self):
         raise NotImplementedError("abstract base class")
@@ -192,6 +201,7 @@
     'Operation',
     force = interp2app(BaseArray.force),
     __add__ = interp2app(BaseArray.descr_add),
+    __sub__ = interp2app(BaseArray.descr_sub),
     __mul__ = interp2app(BaseArray.descr_mul),
 )
 
@@ -251,6 +261,7 @@
     __getitem__ = interp2app(SingleDimArray.descr_getitem),
     __setitem__ = interp2app(SingleDimArray.descr_setitem),
     __add__ = interp2app(BaseArray.descr_add),
+    __sub__ = interp2app(BaseArray.descr_sub),
     __mul__ = interp2app(BaseArray.descr_mul),
     force = interp2app(SingleDimArray.force),
 )
\ No newline at end of file
diff --git a/pypy/module/micronumpy/test/test_numpy.py b/pypy/module/micronumpy/test/test_numpy.py
--- a/pypy/module/micronumpy/test/test_numpy.py
+++ b/pypy/module/micronumpy/test/test_numpy.py
@@ -44,6 +44,28 @@
         for i in range(5):
             assert b[i] == i + 5
 
+    def test_subtract(self):
+        from numpy import array
+        a = array(range(5))
+        b = (a - a).force()
+        for i in range(5):
+            assert b[i] == 0
+
+    def test_subtract_other(self):
+        from numpy import array
+        a = array(range(5))
+        b = array([1, 1, 1, 1, 1])
+        c = (a - b).force()
+        for i in range(5):
+            assert c[i] == i - 1
+
+    def test_subtract_constant(self):
+        from numpy import array
+        a = array(range(5))
+        b = (a - 5).force()
+        for i in range(5):
+            assert b[i] == i - 5
+
     def test_mul(self):
         from numpy import array
         a = array(range(5))
@@ -64,12 +86,12 @@
     def setup_class(cls):
         py.test.skip("unimplemented")
         cls.space = gettestobjspace(usemodules=('micronumpy',))
-    
+
     def test_zeroes(self):
         from numpy import zeros
         ar = zeros(3, dtype=int)
         assert ar[0] == 0
-    
+
     def test_setitem_getitem(self):
         from numpy import zeros
         ar = zeros(8, dtype=int)


More information about the Pypy-commit mailing list