[pypy-commit] pypy numpypy-out: add BroadcastUfunc iter, more tests pass

mattip noreply at buildbot.pypy.org
Sun Feb 19 00:36:49 CET 2012


Author: mattip
Branch: numpypy-out
Changeset: r52623:b3836fce3c20
Date: 2012-02-18 21:34 +0200
http://bitbucket.org/pypy/pypy/changeset/b3836fce3c20/

Log:	add BroadcastUfunc iter, more tests pass

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -786,7 +786,11 @@
         if self.forced_result is not None:
             return self.forced_result.create_sig()
         if self.shape != self.values.shape:
-            xxx 
+            #This happens if out arg is used
+            return signature.BroadcastUfunc(self.ufunc, self.name,
+                                            self.calc_dtype,
+                                            self.values.create_sig(),
+                                            self.res.create_sig())
         return signature.Call1(self.ufunc, self.name, self.calc_dtype,
                                self.values.create_sig())
 
@@ -837,7 +841,8 @@
         if res is None:
             res = W_NDimArray(size, shape, dtype, order)
         assert isinstance(res, BaseArray)
-        Call2.__init__(self, None, 'assign', shape, dtype, dtype, res, child)
+        concr = res.get_concrete()
+        Call2.__init__(self, None, 'assign', shape, dtype, dtype, concr, child)
 
     def create_sig(self):
         sig = signature.ResultSignature(self.res_dtype, self.left.create_sig(),
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -216,13 +216,14 @@
         return self.child.eval(frame, arr.child)
 
 class Call1(Signature):
-    _immutable_fields_ = ['unfunc', 'name', 'child', 'dtype']
+    _immutable_fields_ = ['unfunc', 'name', 'child', 'res', 'dtype']
 
-    def __init__(self, func, name, dtype, child):
+    def __init__(self, func, name, dtype, child, res=None):
         self.unfunc = func
         self.child = child
         self.name = name
         self.dtype = dtype
+        self.res  = res
 
     def hash(self):
         return compute_hash(self.name) ^ intmask(self.child.hash() << 1)
@@ -256,6 +257,29 @@
         v = self.child.eval(frame, arr.values).convert_to(arr.calc_dtype)
         return self.unfunc(arr.calc_dtype, v)
 
+
+class BroadcastUfunc(Call1):
+    def _invent_numbering(self, cache, allnumbers):
+        self.res._invent_numbering(cache, allnumbers)
+        self.child._invent_numbering(new_cache(), allnumbers)
+
+    def debug_repr(self):
+        return 'BroadcastUfunc(%s, %s)' % (self.name, self.child.debug_repr())
+
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import Call1
+
+        assert isinstance(arr, Call1)
+        vtransforms = transforms + [BroadcastTransform(arr.values.shape)]
+        self.child._create_iter(iterlist, arraylist, arr.values, vtransforms)
+        self.res._create_iter(iterlist, arraylist, arr.res, transforms)
+
+    def eval(self, frame, arr):
+        from pypy.module.micronumpy.interp_numarray import Call1
+        assert isinstance(arr, Call1)
+        v = self.child.eval(frame, arr.values).convert_to(arr.calc_dtype)
+        return self.unfunc(arr.calc_dtype, v)
+
 class Call2(Signature):
     _immutable_fields_ = ['binfunc', 'name', 'calc_dtype', 'left', 'right']
 
diff --git a/pypy/module/micronumpy/test/test_outarg.py b/pypy/module/micronumpy/test/test_outarg.py
--- a/pypy/module/micronumpy/test/test_outarg.py
+++ b/pypy/module/micronumpy/test/test_outarg.py
@@ -33,6 +33,7 @@
 
     def test_ufunc_out(self):
         from _numpypy import array, negative, zeros, sin
+        from math import sin as msin
         a = array([[1, 2], [3, 4]])
         c = zeros((2,2,2))
         b = negative(a + a, out=c[1])
@@ -48,7 +49,7 @@
         assert b.shape == c.shape
         a = array([1, 2])
         b = sin(a, out=c)
-        assert(c == [[-1, -2], [-1, -2]]).all()
+        assert(c == [[msin(1), msin(2)]] * 2).all()
         b = sin(a, out=c+c)
         assert (c == b).all()
 


More information about the pypy-commit mailing list