[pypy-commit] pypy refactor-signature: sharing arrays

fijal noreply at buildbot.pypy.org
Fri Dec 16 21:40:35 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: refactor-signature
Changeset: r50606:82d0ce07b964
Date: 2011-12-16 22:40 +0200
http://bitbucket.org/pypy/pypy/changeset/82d0ce07b964/

Log:	sharing arrays

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -708,7 +708,7 @@
     def find_sig(self):
         """ find a correct signature for the array
         """
-        return signature.find_sig(self.create_sig())
+        return signature.find_sig(self.create_sig(), self)
 
 def convert_to_array(space, w_obj):
     if isinstance(w_obj, BaseArray):
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -70,7 +70,7 @@
         shapelen = len(obj.shape)
         sig = find_sig(ReduceSignature(self.func, self.name, dtype,
                                        ScalarSignature(dtype),
-                                       obj.create_sig()))
+                                       obj.create_sig()), obj)
         frame = sig.create_frame(obj)
         if shapelen > 1 and not multidim:
             raise OperationError(space.w_NotImplementedError,
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -2,6 +2,7 @@
 from pypy.rlib.rarithmetic import intmask
 from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
      BroadcastIterator, OneDimIterator, ConstantIterator
+from pypy.rpython.lltypesystem.llmemory import cast_ptr_to_adr
 from pypy.rlib.jit import hint, unroll_safe, promote
 
 # def components_eq(lhs, rhs):
@@ -22,12 +23,16 @@
 def sigeq(one, two):
     return one.eq(two)
 
+def sigeq2(one, two):  # like sigeq, but ignores array_no
+    return one.eq(two, compare_array_no=False)
+
 def sighash(sig):
     return sig.hash()
 
 known_sigs = r_dict(sigeq, sighash)
 
-def find_sig(sig):
+def find_sig(sig, arr):
+    sig.invent_array_numbering(arr)
     try:
         return known_sigs[sig]
     except KeyError:
@@ -36,12 +41,13 @@
         return sig
 
 class NumpyEvalFrame(object):
-    _virtualizable2_ = ['iterators[*]', 'final_iter']
+    _virtualizable2_ = ['iterators[*]', 'final_iter', 'arraylist[*]']
 
     @unroll_safe
-    def __init__(self, iterators):
+    def __init__(self, iterators, arrays):
         self = hint(self, access_directly=True, fresh_virtualizable=True)
         self.iterators = iterators[:]
+        self.arrays = arrays[:]
         for i in range(len(self.iterators)):
             iter = self.iterators[i]
             if not isinstance(iter, ConstantIterator):# or not isinstance(iter, BroadcastIterator):
@@ -61,15 +67,33 @@
         for i in range(len(self.iterators)):
             self.iterators[i] = self.iterators[i].next(shapelen)
 
+def _add_ptr_to_cache(ptr, cache):
+    i = 0
+    for p in cache:
+        if ptr == p:
+            return i
+        i += 1
+    else:
+        res = len(cache)
+        cache.append(ptr)
+        return res
+
 class Signature(object):
-    _attrs_ = ['iter_no']
-    _immutable_fields_ = ['iter_no']
+    _attrs_ = ['iter_no', 'array_no']
+    _immutable_fields_ = ['iter_no', 'array_no']
+
+    array_no = 0
+    iter_no = 0
 
     def invent_numbering(self):
-        cache = r_dict(sigeq, sighash)
+        cache = r_dict(sigeq2, sighash)
         allnumbers = []
         self._invent_numbering(cache, allnumbers)
 
+    def invent_array_numbering(self, arr):  # set array_no on each leaf
+        cache = []  # distinct storages, in the order first seen
+        self._invent_array_numbering(arr, cache)
+
     def _invent_numbering(self, cache, allnumbers):
         try:
             no = cache[self]
@@ -81,8 +105,9 @@
 
     def create_frame(self, arr):
         iterlist = []
-        self._create_iter(iterlist, arr)
-        return NumpyEvalFrame(iterlist)
+        arraylist = []
+        self._create_iter(iterlist, arraylist, arr)
+        return NumpyEvalFrame(iterlist, arraylist)
 
 class ConcreteSignature(Signature):
     _immutable_fields_ = ['dtype']
@@ -90,10 +115,13 @@
     def __init__(self, dtype):
         self.dtype = dtype
 
-    def eq(self, other):
+    def eq(self, other, compare_array_no=True):
         if type(self) is not type(other):
             return False
         assert isinstance(other, ConcreteSignature)
+        if compare_array_no:
+            if self.array_no != other.array_no:
+                return False
         return self.dtype is other.dtype
 
     def hash(self):
@@ -103,41 +131,59 @@
     def debug_repr(self):
         return 'Array'
 
-    def _create_iter(self, iterlist, arr):
+    def _invent_array_numbering(self, arr, cache):
+        from pypy.module.micronumpy.interp_numarray import W_NDimArray
+        assert isinstance(arr, W_NDimArray)
+        self.array_no = _add_ptr_to_cache(arr.storage, cache)  # by storage
+
+    def _create_iter(self, iterlist, arraylist, arr):
         from pypy.module.micronumpy.interp_numarray import W_NDimArray
         assert isinstance(arr, W_NDimArray)
         if self.iter_no >= len(iterlist):
             iterlist.append(ArrayIterator(arr.size))
+        if self.array_no >= len(arraylist):
+            arraylist.append(arr.storage)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import W_NDimArray
         assert isinstance(arr, W_NDimArray)
         iter = frame.iterators[self.iter_no]
-        return self.dtype.getitem(arr.storage, iter.offset)
+        return self.dtype.getitem(frame.arrays[self.array_no], iter.offset)
 
 class ForcedSignature(ArraySignature):
     def debug_repr(self):
         return 'ForcedArray'
 
-    def _create_iter(self, iterlist, arr):
+    def _invent_array_numbering(self, arr, cache):
+        from pypy.module.micronumpy.interp_numarray import VirtualArray
+        assert isinstance(arr, VirtualArray)
+        arr = arr.forced_result  # number by the forced result's storage
+        self.array_no = _add_ptr_to_cache(arr.storage, cache)
+
+    def _create_iter(self, iterlist, arraylist, arr):
         from pypy.module.micronumpy.interp_numarray import VirtualArray
         assert isinstance(arr, VirtualArray)
         arr = arr.forced_result
         if self.iter_no >= len(iterlist):
             iterlist.append(ArrayIterator(arr.size))
+        if self.array_no >= len(arraylist):
+            arraylist.append(arr.storage)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import VirtualArray
         assert isinstance(arr, VirtualArray)
         arr = arr.forced_result
         iter = frame.iterators[self.iter_no]
-        return self.dtype.getitem(arr.storage, iter.offset)    
+        return self.dtype.getitem(frame.arrays[self.array_no], iter.offset)    
 
 class ScalarSignature(ConcreteSignature):
     def debug_repr(self):
         return 'Scalar'
 
-    def _create_iter(self, iterlist, arr):
+    def _invent_array_numbering(self, arr, cache):
+        pass  # scalars get no storage number; array_no keeps class default
+
+    def _create_iter(self, iterlist, arraylist, arr):
         if self.iter_no >= len(iterlist):
             iter = ConstantIterator()
             iterlist.append(iter)
@@ -153,11 +199,11 @@
     def __init__(self, child):
         self.child = child
     
-    def eq(self, other):
+    def eq(self, other, compare_array_no=True):
         if type(self) is not type(other):
             return False
         assert isinstance(other, ViewSignature)
-        return self.child.eq(other.child)
+        return self.child.eq(other.child, compare_array_no)
 
     def hash(self):
         return self.child.hash() ^ 0x12345
@@ -171,25 +217,33 @@
         allnumbers.append(no)
         self.iter_no = no
 
-    def _create_iter(self, iterlist, arr):
+    def _invent_array_numbering(self, arr, cache):
+        from pypy.module.micronumpy.interp_numarray import ConcreteViewArray
+        assert isinstance(arr, ConcreteViewArray)  # keyed on parent storage
+        self.array_no = _add_ptr_to_cache(arr.parent.storage, cache)
+
+    def _create_iter(self, iterlist, arraylist, arr):
         from pypy.module.micronumpy.interp_numarray import ConcreteViewArray
 
         assert isinstance(arr, ConcreteViewArray)
         if self.iter_no >= len(iterlist):
             iterlist.append(ViewIterator(arr))
+        if self.array_no >= len(arraylist):
+            arraylist.append(arr.parent.storage)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import W_NDimSlice
         assert isinstance(arr, W_NDimSlice)
         arr = arr.get_concrete()
         iter = frame.iterators[self.iter_no]
-        return arr.find_dtype().getitem(arr.parent.storage, iter.offset)
+        return arr.find_dtype().getitem(frame.arrays[self.array_no],
+                                        iter.offset)
 
 class FlatiterSignature(ViewSignature):
     def debug_repr(self):
         return 'FlatIter(%s)' % self.child.debug_repr()
 
-    def _create_iter(self, iterlist, arr):
+    def _create_iter(self, iterlist, arraylist, arr):
         raise NotImplementedError
 
 class Call1(Signature):
@@ -203,11 +257,12 @@
     def hash(self):
         return compute_hash(self.name) ^ intmask(self.child.hash() << 1)
 
-    def eq(self, other):
+    def eq(self, other, compare_array_no=True):
         if type(self) is not type(other):
             return False
         assert isinstance(other, Call1)
-        return self.unfunc is other.unfunc and self.child.eq(other.child)
+        return (self.unfunc is other.unfunc and
+                self.child.eq(other.child, compare_array_no))
 
     def debug_repr(self):
         return 'Call1(%s, %s)' % (self.name, self.child.debug_repr())
@@ -215,10 +270,15 @@
     def _invent_numbering(self, cache, allnumbers):
         self.child._invent_numbering(cache, allnumbers)
 
-    def _create_iter(self, iterlist, arr):
+    def _invent_array_numbering(self, arr, cache):
         from pypy.module.micronumpy.interp_numarray import Call1
         assert isinstance(arr, Call1)
-        self.child._create_iter(iterlist, arr.values)
+        self.child._invent_array_numbering(arr.values, cache)
+
+    def _create_iter(self, iterlist, arraylist, arr):
+        from pypy.module.micronumpy.interp_numarray import Call1
+        assert isinstance(arr, Call1)
+        self.child._create_iter(iterlist, arraylist, arr.values)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call1
@@ -240,24 +300,31 @@
         return (compute_hash(self.name) ^ intmask(self.left.hash() << 1) ^
                 intmask(self.right.hash() << 2))
 
-    def eq(self, other):
+    def eq(self, other, compare_array_no=True):
         if type(self) is not type(other):
             return False
         assert isinstance(other, Call2)
         return (self.binfunc is other.binfunc and
                 self.calc_dtype is other.calc_dtype and
-                self.left.eq(other.left) and self.right.eq(other.right))
+                self.left.eq(other.left, compare_array_no) and
+                self.right.eq(other.right, compare_array_no))
+
+    def _invent_array_numbering(self, arr, cache):
+        from pypy.module.micronumpy.interp_numarray import Call2
+        assert isinstance(arr, Call2)
+        self.left._invent_array_numbering(arr.left, cache)  # same order as
+        self.right._invent_array_numbering(arr.right, cache)  # _create_iter
 
     def _invent_numbering(self, cache, allnumbers):
         self.left._invent_numbering(cache, allnumbers)
         self.right._invent_numbering(cache, allnumbers)
 
-    def _create_iter(self, iterlist, arr):
+    def _create_iter(self, iterlist, arraylist, arr):
         from pypy.module.micronumpy.interp_numarray import Call2
         
         assert isinstance(arr, Call2)
-        self.left._create_iter(iterlist, arr.left)
-        self.right._create_iter(iterlist, arr.right)
+        self.left._create_iter(iterlist, arraylist, arr.left)
+        self.right._create_iter(iterlist, arraylist, arr.right)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call2
@@ -271,11 +338,14 @@
                                   self.right.debug_repr())
 
 class ReduceSignature(Call2):
-    def _create_iter(self, iterlist, arr):
-        self.right._create_iter(iterlist, arr)
+    def _create_iter(self, iterlist, arraylist, arr):
+        self.right._create_iter(iterlist, arraylist, arr)
 
     def _invent_numbering(self, cache, allnumbers):
         self.right._invent_numbering(cache, allnumbers)
 
+    def _invent_array_numbering(self, arr, cache):  # mirrors _create_iter:
+        self.right._invent_array_numbering(arr, cache)  # only right is used
+
     def eval(self, frame, arr):
         return self.right.eval(frame, arr)
diff --git a/pypy/module/micronumpy/test/test_base.py b/pypy/module/micronumpy/test/test_base.py
--- a/pypy/module/micronumpy/test/test_base.py
+++ b/pypy/module/micronumpy/test/test_base.py
@@ -14,6 +14,7 @@
         bool_dtype = get_dtype_cache(space).w_booldtype
 
         ar = W_NDimArray(10, [10], dtype=float64_dtype)
+        ar2 = W_NDimArray(10, [10], dtype=float64_dtype)
         v1 = ar.descr_add(space, ar)
         v2 = ar.descr_add(space, Scalar(float64_dtype, 2.0))
         sig1 = v1.find_sig()
@@ -21,6 +22,10 @@
         assert v1 is not v2
         assert sig1.left.iter_no == sig1.right.iter_no
         assert sig2.left.iter_no != sig2.right.iter_no
+        assert sig1.left.array_no == sig1.right.array_no
+        sig1b = ar2.descr_add(space, ar).find_sig()
+        assert sig1b.left.array_no != sig1b.right.array_no
+        assert sig1b is not sig1
         v3 = ar.descr_add(space, Scalar(float64_dtype, 1.0))
         sig3 = v3.find_sig()
         assert sig2 is sig3


More information about the pypy-commit mailing list