[pypy-commit] pypy ndarray-subtype: hack at compile to support ndarray subclasses

mattip noreply at buildbot.pypy.org
Thu Jul 11 21:57:16 CEST 2013


Author: Matti Picus <matti.picus at gmail.com>
Branch: ndarray-subtype
Changeset: r65350:3464b2eff3be
Date: 2013-07-09 19:46 +0300
http://bitbucket.org/pypy/pypy/changeset/3464b2eff3be/

Log:	hack at compile to support ndarray subclasses

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -35,10 +35,12 @@
 class BadToken(Exception):
     pass
 
+
 SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
                         "unegative", "flat", "tostring","count_nonzero",
                         "argsort"]
 TWO_ARG_FUNCTIONS = ["dot", 'take']
+TWO_ARG_FUNCTIONS_OR_NONE = ['view']
 THREE_ARG_FUNCTIONS = ['where']
 
 class W_TypeObject(W_Root):
@@ -184,17 +186,28 @@
 
     def is_true(self, w_obj):
         assert isinstance(w_obj, BoolObject)
-        return False
-        #return w_obj.boolval
+        return w_obj.boolval
 
     def is_w(self, w_obj, w_what):
         return w_obj is w_what
 
+    def issubtype(self, w_type1, w_type2):
+        if not w_type2:
+            return self.wrap(False)
+        return self.wrap(issubclass(w_type1, w_type2))
+
     def type(self, w_obj):
-        return w_obj.tp
+        try:
+            return w_obj.tp
+        except AttributeError:
+            if isinstance(w_obj, W_NDimArray):
+                return W_NDimArray
+            if issubclass(w_obj, W_NDimArray):
+                return W_NDimArray
+            return None
 
     def gettypefor(self, w_obj):
-        return None
+        return self.type(w_obj)
 
     def call_function(self, tp, w_dtype):
         return w_dtype
@@ -205,7 +218,9 @@
         return what
 
     def allocate_instance(self, klass, w_subtype):
-        return instantiate(klass)
+        inst = instantiate(klass)
+        inst.tp = klass
+        return inst
 
     def newtuple(self, list_w):
         return ListObject(list_w)
@@ -329,6 +344,8 @@
         self.name = name.strip(" ")
 
     def execute(self, interp):
+        if self.name == 'None':
+            return None
         return interp.variables[self.name]
 
     def __repr__(self):
@@ -451,6 +468,32 @@
     def __repr__(self):
         return 'slice(%s,%s,%s)' % (self.start, self.stop, self.step)
 
+class ArrayClass(Node):
+    def __init__(self):
+        self.v = W_NDimArray
+
+    def execute(self, interp):
+       return self.v
+
+    def __repr__(self):
+        return '<class W_NDimArray>'
+
+class DtypeClass(Node):
+    def __init__(self, dt):
+        self.v = dt
+
+    def execute(self, interp):
+        if self.v == 'int':
+            dtype = get_dtype_cache(interp.space).w_int64dtype
+        elif self.v == 'float':
+            dtype = get_dtype_cache(interp.space).w_float64dtype
+        else:
+            raise BadToken('unknown v to dtype "%s"' % self.v)
+        return dtype
+
+    def __repr__(self):
+        return '<class %s dtype>' % self.v
+
 class Execute(Node):
     def __init__(self, expr):
         self.expr = expr
@@ -533,6 +576,14 @@
                 w_res = where(interp.space, arr, arg1, arg2)
             else:
                 assert False
+        elif self.name in TWO_ARG_FUNCTIONS_OR_NONE:
+            if len(self.args) != 2:
+                raise ArgumentMismatch
+            arg = self.args[1].execute(interp)
+            if self.name == 'view':
+                w_res = arr.descr_view(interp.space, arg)
+            else:
+                assert False
         else:
             raise WrongFunctionName
         if isinstance(w_res, W_NDimArray):
@@ -652,6 +703,12 @@
             if token.name == 'identifier':
                 if tokens.remaining() and tokens.get(0).name == 'paren_left':
                     stack.append(self.parse_function_call(token.v, tokens))
+                elif token.v.strip(' ') == 'ndarray':
+                    stack.append(ArrayClass())
+                elif token.v.strip(' ') == 'int':
+                    stack.append(DtypeClass('int'))
+                elif token.v.strip(' ') == 'float':
+                    stack.append(DtypeClass('float'))
                 else:
                     stack.append(Variable(token.v))
             elif token.name == 'array_left':
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -260,8 +260,9 @@
         return self.implementation.get_scalar_value()
 
     def descr_copy(self, space):
-        return wrap_impl(space, space.type(self),
-                         self, self.implementation.copy(space))
+        copy = self.implementation.copy(space)
+        w_subtype = space.type(self)
+        return wrap_impl(space, w_subtype, self, copy)
 
     def descr_get_real(self, space):
         return wrap_impl(space, space.type(self), self,
@@ -629,12 +630,13 @@
             "trace not implemented yet"))
 
     def descr_view(self, space, w_dtype=None, w_type=None) :
+        print w_dtype, w_type
         if not w_type and w_dtype:
             try:
-                if w_dtype.issubtype(space.gettypefor(W_NDimArray)):
+                if space.is_true(space.issubtype(w_dtype, space.gettypefor(W_NDimArray))):
                     w_type = w_dtype
                     w_dtype = None
-            except:
+            except (OperationError, TypeError):
                 pass
         if w_dtype:
             dtype = space.interp_w(interp_dtype.W_Dtype,
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -33,11 +33,11 @@
     lhs_for_subtype = w_lhs
     rhs_for_subtype = w_rhs
     #it may be something like a FlatIter, which is not an ndarray
-    if not lhs_type.issubtype(w_ndarray):
+    if not space.is_true(space.issubtype(lhs_type, w_ndarray)):
         lhs_type = space.type(w_lhs.base)
         lhs_for_subtype = w_lhs.base
-    if not rhs_type.issubtype(w_ndarray):
-        rhs_type = space.gettypefor(w_rhs.base)
+    if not space.is_true(space.issubtype(rhs_type, w_ndarray)):
+        rhs_type = space.type(w_rhs.base)
         rhs_for_subtype = w_rhs.base
     if space.is_w(lhs_type, w_ndarray) and not space.is_w(rhs_type, w_ndarray):
         w_lhs, w_rhs = w_rhs, w_lhs
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -2,7 +2,7 @@
 import py
 from pypy.module.micronumpy.compile import (numpy_compile, Assignment,
     ArrayConstant, FloatConstant, Operator, Variable, RangeConstant, Execute,
-    FunctionCall, FakeSpace)
+    FunctionCall, FakeSpace, W_NDimArray)
 
 
 class TestCompiler(object):
@@ -84,6 +84,7 @@
         assert interp.code.statements[0] == Assignment(
             'a', Operator(Variable('b'), "+", FloatConstant(3)))
 
+
 class TestRunner(object):
     def run(self, code):
         interp = numpy_compile(code)
@@ -290,4 +291,32 @@
         ''')
         assert interp.results[0].real == 0
         assert interp.results[0].imag == 1
-        
+
+    def test_view_none(self):
+        interp = self.run('''
+        a = [1, 0, 3, 0]
+        b = None
+        c = view(a, b)
+        c -> 0
+        ''')
+        assert interp.results[0].value == 1
+
+    def test_view_ndarray(self):
+        interp = self.run('''
+        a = [1, 0, 3, 0]
+        b = ndarray
+        c = view(a, b)
+        c
+        ''')
+        results = interp.results[0]
+        assert isinstance(results, W_NDimArray)
+
+    def test_view_dtype(self):
+        interp = self.run('''
+        a = [1, 0, 3, 0]
+        b = int
+        c = view(a, b)
+        c
+        ''')
+        results = interp.results[0]
+        assert isinstance(results, W_NDimArray)


More information about the pypy-commit mailing list