[pypy-commit] pypy ndarray-subtype: hack at compile to support ndarray subclasses
mattip
noreply at buildbot.pypy.org
Thu Jul 11 21:57:16 CEST 2013
Author: Matti Picus <matti.picus at gmail.com>
Branch: ndarray-subtype
Changeset: r65350:3464b2eff3be
Date: 2013-07-09 19:46 +0300
http://bitbucket.org/pypy/pypy/changeset/3464b2eff3be/
Log: hack at compile to support ndarray subclasses
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -35,10 +35,12 @@
class BadToken(Exception):
pass
+
SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
"unegative", "flat", "tostring","count_nonzero",
"argsort"]
TWO_ARG_FUNCTIONS = ["dot", 'take']
+TWO_ARG_FUNCTIONS_OR_NONE = ['view']
THREE_ARG_FUNCTIONS = ['where']
class W_TypeObject(W_Root):
@@ -184,17 +186,28 @@
def is_true(self, w_obj):
assert isinstance(w_obj, BoolObject)
- return False
- #return w_obj.boolval
+ return w_obj.boolval
def is_w(self, w_obj, w_what):
return w_obj is w_what
+ def issubtype(self, w_type1, w_type2):
+ if not w_type2:
+ return self.wrap(False)
+ return self.wrap(issubclass(w_type1, w_type2))
+
def type(self, w_obj):
- return w_obj.tp
+ try:
+ return w_obj.tp
+ except AttributeError:
+ if isinstance(w_obj, W_NDimArray):
+ return W_NDimArray
+ if issubclass(w_obj, W_NDimArray):
+ return W_NDimArray
+ return None
def gettypefor(self, w_obj):
- return None
+ return self.type(w_obj)
def call_function(self, tp, w_dtype):
return w_dtype
@@ -205,7 +218,9 @@
return what
def allocate_instance(self, klass, w_subtype):
- return instantiate(klass)
+ inst = instantiate(klass)
+ inst.tp = klass
+ return inst
def newtuple(self, list_w):
return ListObject(list_w)
@@ -329,6 +344,8 @@
self.name = name.strip(" ")
def execute(self, interp):
+ if self.name == 'None':
+ return None
return interp.variables[self.name]
def __repr__(self):
@@ -451,6 +468,32 @@
def __repr__(self):
return 'slice(%s,%s,%s)' % (self.start, self.stop, self.step)
+class ArrayClass(Node):
+ def __init__(self):
+ self.v = W_NDimArray
+
+ def execute(self, interp):
+ return self.v
+
+ def __repr__(self):
+ return '<class W_NDimArray>'
+
+class DtypeClass(Node):
+ def __init__(self, dt):
+ self.v = dt
+
+ def execute(self, interp):
+ if self.v == 'int':
+ dtype = get_dtype_cache(interp.space).w_int64dtype
+ elif self.v == 'float':
+ dtype = get_dtype_cache(interp.space).w_float64dtype
+ else:
+ raise BadToken('unknown v to dtype "%s"' % self.v)
+ return dtype
+
+ def __repr__(self):
+ return '<class %s dtype>' % self.v
+
class Execute(Node):
def __init__(self, expr):
self.expr = expr
@@ -533,6 +576,14 @@
w_res = where(interp.space, arr, arg1, arg2)
else:
assert False
+ elif self.name in TWO_ARG_FUNCTIONS_OR_NONE:
+ if len(self.args) != 2:
+ raise ArgumentMismatch
+ arg = self.args[1].execute(interp)
+ if self.name == 'view':
+ w_res = arr.descr_view(interp.space, arg)
+ else:
+ assert False
else:
raise WrongFunctionName
if isinstance(w_res, W_NDimArray):
@@ -652,6 +703,12 @@
if token.name == 'identifier':
if tokens.remaining() and tokens.get(0).name == 'paren_left':
stack.append(self.parse_function_call(token.v, tokens))
+ elif token.v.strip(' ') == 'ndarray':
+ stack.append(ArrayClass())
+ elif token.v.strip(' ') == 'int':
+ stack.append(DtypeClass('int'))
+ elif token.v.strip(' ') == 'float':
+ stack.append(DtypeClass('float'))
else:
stack.append(Variable(token.v))
elif token.name == 'array_left':
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -260,8 +260,9 @@
return self.implementation.get_scalar_value()
def descr_copy(self, space):
- return wrap_impl(space, space.type(self),
- self, self.implementation.copy(space))
+ copy = self.implementation.copy(space)
+ w_subtype = space.type(self)
+ return wrap_impl(space, w_subtype, self, copy)
def descr_get_real(self, space):
return wrap_impl(space, space.type(self), self,
@@ -629,12 +630,13 @@
"trace not implemented yet"))
def descr_view(self, space, w_dtype=None, w_type=None) :
+ print w_dtype, w_type
if not w_type and w_dtype:
try:
- if w_dtype.issubtype(space.gettypefor(W_NDimArray)):
+ if space.is_true(space.issubtype(w_dtype, space.gettypefor(W_NDimArray))):
w_type = w_dtype
w_dtype = None
- except:
+ except (OperationError, TypeError):
pass
if w_dtype:
dtype = space.interp_w(interp_dtype.W_Dtype,
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -33,11 +33,11 @@
lhs_for_subtype = w_lhs
rhs_for_subtype = w_rhs
#it may be something like a FlatIter, which is not an ndarray
- if not lhs_type.issubtype(w_ndarray):
+ if not space.is_true(space.issubtype(lhs_type, w_ndarray)):
lhs_type = space.type(w_lhs.base)
lhs_for_subtype = w_lhs.base
- if not rhs_type.issubtype(w_ndarray):
- rhs_type = space.gettypefor(w_rhs.base)
+ if not space.is_true(space.issubtype(rhs_type, w_ndarray)):
+ rhs_type = space.type(w_rhs.base)
rhs_for_subtype = w_rhs.base
if space.is_w(lhs_type, w_ndarray) and not space.is_w(rhs_type, w_ndarray):
w_lhs, w_rhs = w_rhs, w_lhs
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -2,7 +2,7 @@
import py
from pypy.module.micronumpy.compile import (numpy_compile, Assignment,
ArrayConstant, FloatConstant, Operator, Variable, RangeConstant, Execute,
- FunctionCall, FakeSpace)
+ FunctionCall, FakeSpace, W_NDimArray)
class TestCompiler(object):
@@ -84,6 +84,7 @@
assert interp.code.statements[0] == Assignment(
'a', Operator(Variable('b'), "+", FloatConstant(3)))
+
class TestRunner(object):
def run(self, code):
interp = numpy_compile(code)
@@ -290,4 +291,32 @@
''')
assert interp.results[0].real == 0
assert interp.results[0].imag == 1
-
+
+ def test_view_none(self):
+ interp = self.run('''
+ a = [1, 0, 3, 0]
+ b = None
+ c = view(a, b)
+ c -> 0
+ ''')
+ assert interp.results[0].value == 1
+
+ def test_view_ndarray(self):
+ interp = self.run('''
+ a = [1, 0, 3, 0]
+ b = ndarray
+ c = view(a, b)
+ c
+ ''')
+ results = interp.results[0]
+ assert isinstance(results, W_NDimArray)
+
+ def test_view_dtype(self):
+ interp = self.run('''
+ a = [1, 0, 3, 0]
+ b = int
+ c = view(a, b)
+ c
+ ''')
+ results = interp.results[0]
+ assert isinstance(results, W_NDimArray)
More information about the pypy-commit
mailing list