[pypy-commit] pypy numpy-refactor: Another refactor. This time we use __extend__ to make sure everyone can import W_NDimArray

fijal noreply at buildbot.pypy.org
Tue Sep 4 12:00:21 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57115:c504e7e46331
Date: 2012-09-04 11:59 +0200
http://bitbucket.org/pypy/pypy/changeset/c504e7e46331/

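Log:	Another refactor. This time we use __extend__ to make sure everyone
	can import W_NDimArray: the class now lives in micronumpy/base.py and
	interp_numarray.py extends it, so other modules no longer need
	function-level imports of interp_numarray to avoid circular imports.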

diff --git a/pypy/module/micronumpy/arrayimpl/__init__.py b/pypy/module/micronumpy/arrayimpl/__init__.py
--- a/pypy/module/micronumpy/arrayimpl/__init__.py
+++ b/pypy/module/micronumpy/arrayimpl/__init__.py
@@ -1,10 +0,0 @@
-
-from pypy.module.micronumpy.arrayimpl import scalar, concrete
-
-create_slice = concrete.SliceArray
-
-def create_implementation(shape, dtype, order):
-    if not shape:
-        return scalar.Scalar(dtype)
-    else:
-        return concrete.ConcreteArray(shape, dtype, order)
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -1,6 +1,7 @@
 
 from pypy.module.micronumpy.arrayimpl import base
 from pypy.module.micronumpy import support, loop
+from pypy.module.micronumpy.base import convert_to_array
 from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
      calculate_broadcast_strides
 from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk, RecordChunk
@@ -73,22 +74,6 @@
         return self._done
 
 
-def calc_strides(shape, dtype, order):
-    strides = []
-    backstrides = []
-    s = 1
-    shape_rev = shape[:]
-    if order == 'C':
-        shape_rev.reverse()
-    for sh in shape_rev:
-        strides.append(s * dtype.get_size())
-        backstrides.append(s * (sh - 1) * dtype.get_size())
-        s *= sh
-    if order == 'C':
-        strides.reverse()
-        backstrides.reverse()
-    return strides, backstrides
-
 def int_w(space, w_obj):
     # a special version that respects both __index__ and __int__
     # XXX add __index__ support
@@ -101,13 +86,16 @@
     start = 0
     parent = None
     
-    def __init__(self, shape, dtype, order):
+    def __init__(self, shape, dtype, order, strides, backstrides, storage=None):
         self.shape = shape
         self.size = support.product(shape) * dtype.get_size()
-        self.storage = dtype.itemtype.malloc(self.size)
-        self.strides, self.backstrides = calc_strides(shape, dtype, order)
+        if storage is None:
+            storage = dtype.itemtype.malloc(self.size)
+        self.storage = storage
         self.order = order
         self.dtype = dtype
+        self.strides = strides
+        self.backstrides = backstrides
 
     def get_shape(self):
         return self.shape
@@ -129,7 +117,8 @@
         self.dtype.fill(self.storage, box, 0, self.size)
 
     def copy(self):
-        impl = ConcreteArray(self.shape, self.dtype, self.order)
+        impl = ConcreteArray(self.shape, self.dtype, self.order, self.strides,
+                             self.backstrides)
         return loop.setslice(self.shape, impl, self)
 
     def setslice(self, space, arr):
@@ -145,7 +134,6 @@
     def get_size(self):
         return self.size // self.dtype.itemtype.get_element_size()
 
-
     def reshape(self, space, new_shape):
         # Since we got to here, prod(new_shape) == self.size
         new_strides = None
@@ -250,11 +238,15 @@
             item = self._single_item_index(space, w_index)
             self.setitem(item, self.dtype.coerce(space, w_value))
         except IndexError:
-            w_value = support.convert_to_array(space, w_value)
+            w_value = convert_to_array(space, w_value)
             chunks = self._prepare_slice_args(space, w_index)
             view = chunks.apply(self)
             view.implementation.setslice(space, w_value)
 
+    #def setshape(self, space, new_shape):
+    #    self.shape = new_shape
+    #    self.calc_strides(new_shape)
+
     def transpose(self):
         if len(self.shape) < 2:
             return self
@@ -295,3 +287,36 @@
             return OneDimViewIterator(self)
         return MultiDimViewIterator(self.parent, self.start, self.strides,
                                     self.backstrides, self.shape)
+
+    def set_shape(self, space, new_shape):
+        if len(self.shape) < 2 or self.size == 0:
+            # TODO: this code could be refactored into calc_strides
+            # but then calc_strides would have to accept a stepping factor
+            strides = []
+            backstrides = []
+            dtype = self.dtype
+            s = self.strides[0] // dtype.get_size()
+            if self.order == 'C':
+                new_shape.reverse()
+            for sh in new_shape:
+                strides.append(s * dtype.get_size())
+                backstrides.append(s * (sh - 1) * dtype.get_size())
+                s *= max(1, sh)
+            if self.order == 'C':
+                strides.reverse()
+                backstrides.reverse()
+                new_shape.reverse()
+            return SliceArray(self.start, strides, backstrides, new_shape,
+                              self)
+        new_strides = calc_new_strides(new_shape, self.shape, self.strides,
+                                       self.order)
+        if new_strides is None:
+            raise OperationError(space.w_AttributeError, space.wrap(
+                          "incompatible shape for a non-contiguous array"))
+        new_backstrides = [0] * len(new_shape)
+        for nd in range(len(new_shape)):
+            new_backstrides[nd] = (new_shape[nd] - 1) * new_strides[nd]
+        xxx
+        self.strides = new_strides[:]
+        self.backstrides = new_backstrides
+        self.shape = new_shape[:]
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -19,8 +19,8 @@
         return False
 
 class Scalar(base.BaseArrayImplementation):
-    def __init__(self, dtype):
-        self.value = None
+    def __init__(self, dtype, value=None):
+        self.value = value
         self.dtype = dtype
 
     def is_scalar(self):
@@ -32,9 +32,6 @@
     def create_iter(self, shape):
         return ScalarIterator(self.value)
 
-    def set_scalar_value(self, value):
-        self.value = value
-
     def get_scalar_value(self):
         return self.value
 
@@ -57,3 +54,6 @@
         raise OperationError(space.w_IndexError,
                              space.wrap("scalars cannot be indexed"))
         
+    def set_shape(self, new_shape):
+        import pdb
+        pdb.set_trace()
diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/base.py
@@ -0,0 +1,53 @@
+
+from pypy.interpreter.baseobjspace import Wrappable
+from pypy.tool.pairtype import extendabletype
+from pypy.module.micronumpy.support import calc_strides
+
+class W_NDimArray(Wrappable):
+    __metaclass__ = extendabletype
+
+    def __init__(self, implementation):
+        self.implementation = implementation
+    
+    @classmethod
+    def from_shape(cls, shape, dtype, order='C', storage=None):
+        from pypy.module.micronumpy.arrayimpl import concrete
+
+        assert shape
+        strides, backstrides = calc_strides(shape, dtype, order)
+        impl = concrete.ConcreteArray(shape, dtype, order, strides,
+                                      backstrides, storage)
+        return W_NDimArray(impl)
+
+    @classmethod
+    def from_strides(cls):
+        xxx
+
+    @classmethod
+    def new_slice(cls, offset, strides, backstrides, shape, parent):
+        from pypy.module.micronumpy.arrayimpl import concrete
+
+        impl = concrete.SliceArray(offset, strides, backstrides, shape, parent)
+        return W_NDimArray(impl)
+
+    @classmethod
+    def new_scalar(cls, space, dtype, w_val=None):
+        from pypy.module.micronumpy.arrayimpl import scalar
+
+        if w_val is not None:
+            w_val = dtype.coerce(space, w_val)
+        return W_NDimArray(scalar.Scalar(dtype, w_val))
+
+def convert_to_array(space, w_obj):
+    from pypy.module.micronumpy.interp_numarray import array
+    from pypy.module.micronumpy import interp_ufuncs
+    
+    if isinstance(w_obj, W_NDimArray):
+        return w_obj
+    elif space.issequence_w(w_obj):
+        # Convert to array.
+        return array(space, w_obj, w_order=None)
+    else:
+        # If it's a scalar
+        dtype = interp_ufuncs.find_dtype_for_scalar(space, w_obj)
+        return W_NDimArray.new_scalar(space, dtype, w_obj)
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -1,5 +1,5 @@
 
-from pypy.module.micronumpy.support import convert_to_array, create_array
+from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
 from pypy.module.micronumpy import loop
 from pypy.interpreter.error import OperationError
 
@@ -70,5 +70,5 @@
     x = convert_to_array(space, w_x)
     y = convert_to_array(space, w_y)
     dtype = arr.get_dtype()
-    out = create_array(arr.get_shape(), dtype)
+    out = W_NDimArray.from_shape(arr.get_shape(), dtype)
     return loop.where(out, arr, x, y, dtype)
diff --git a/pypy/module/micronumpy/interp_boxes.py b/pypy/module/micronumpy/interp_boxes.py
--- a/pypy/module/micronumpy/interp_boxes.py
+++ b/pypy/module/micronumpy/interp_boxes.py
@@ -8,6 +8,7 @@
 from pypy.objspace.std.inttype import int_typedef
 from pypy.rlib.rarithmetic import LONG_BIT
 from pypy.tool.sourcetools import func_with_new_name
+from pypy.module.micronumpy.base import W_NDimArray
 
 MIXIN_32 = (int_typedef,) if LONG_BIT == 32 else ()
 MIXIN_64 = (int_typedef,) if LONG_BIT == 64 else ()
@@ -246,11 +247,10 @@
 
 class W_StringBox(W_CharacterBox):
     def descr__new__string_box(space, w_subtype, w_arg):
-        from pypy.module.micronumpy.interp_numarray import W_NDimArray
         from pypy.module.micronumpy.interp_dtype import new_string_dtype
 
         arg = space.str_w(space.str(w_arg))
-        arr = W_NDimArray([1], new_string_dtype(space, len(arg)))
+        arr = W_NDimArray.from_shape([1], new_string_dtype(space, len(arg)))
         for i in range(len(arg)):
             arr.storage[i] = arg[i]
         return W_StringBox(arr, 0, arr.dtype)
@@ -258,11 +258,10 @@
 
 class W_UnicodeBox(W_CharacterBox):
     def descr__new__unicode_box(space, w_subtype, w_arg):
-        from pypy.module.micronumpy.interp_numarray import W_NDimArray
         from pypy.module.micronumpy.interp_dtype import new_unicode_dtype
 
         arg = space.unicode_w(unicode_from_object(space, w_arg))
-        arr = W_NDimArray([1], new_unicode_dtype(space, len(arg)))
+        arr = W_NDimArray.from_shape([1], new_unicode_dtype(space, len(arg)))
         # XXX not this way, we need store
         #for i in range(len(arg)):
         #    arr.storage[i] = arg[i]
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1,17 +1,15 @@
 
-from pypy.interpreter.baseobjspace import Wrappable
 from pypy.interpreter.error import operationerrfmt, OperationError
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec
+from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy import interp_dtype, interp_ufuncs, support
-from pypy.module.micronumpy.arrayimpl import create_implementation, create_slice
 from pypy.module.micronumpy.strides import find_shape_and_elems,\
      get_shape_from_iterable
 from pypy.module.micronumpy.interp_support import unwrap_axis_arg
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib import jit
-from pypy.rlib.objectmodel import instantiate
 from pypy.rlib.rstring import StringBuilder
 
 def _find_shape(space, w_size):
@@ -22,24 +20,7 @@
         shape.append(space.int_w(w_item))
     return shape
 
-def scalar_w(space, dtype, w_object):
-    arr = W_NDimArray([], dtype)
-    arr.implementation.set_scalar_value(dtype.coerce(space, w_object))
-    return arr
-
-def slice_w(start, strides, backstrides, shape, parent):
-    arr = instantiate(W_NDimArray)
-    arr.implementation = create_slice(start, strides, backstrides, shape,
-                                      parent)
-    return arr
-
-class W_NDimArray(Wrappable):
-    def __init__(self, shape, dtype, buffer=0, offset=0, strides=None,
-                 order='C'):
-        if strides is not None or offset != 0 or buffer != 0:
-            raise Exception("unsupported args")
-        self.implementation = create_implementation(shape, dtype, order)
-
+class __extend__(W_NDimArray):
     @jit.unroll_safe
     def descr_get_shape(self, space):
         shape = self.get_shape()
@@ -49,8 +30,8 @@
         return self.implementation.get_shape()
 
     def descr_set_shape(self, space, w_new_shape):
-        self.implementation = self.implementation.set_shape(
-            _find_shape(space, w_new_shape))
+        self.implementation = self.implementation.set_shape(space,
+            get_shape_from_iterable(space, self.get_size(), w_new_shape))
 
     def get_dtype(self):
         return self.implementation.dtype
@@ -125,9 +106,7 @@
         return self.implementation.get_scalar_value()
 
     def descr_copy(self, space):
-        arr = instantiate(W_NDimArray)
-        arr.implementation = self.implementation.copy()
-        return arr
+        return W_NDimArray(self.implementation.copy())
 
     def descr_reshape(self, space, args_w):
         """reshape(...)
@@ -157,9 +136,7 @@
         return arr
 
     def descr_get_transpose(self, space):
-        arr = instantiate(W_NDimArray)
-        arr.implementation = self.implementation.transpose()
-        return arr
+        return W_NDimArray(self.implementation.transpose())
 
     # --------------------- binary operations ----------------------------
 
@@ -197,7 +174,7 @@
 
     def _binop_right_impl(ufunc_name):
         def impl(self, space, w_other, w_out=None):
-            w_other = scalar_w(space,
+            w_other = W_NDimArray.new_scalar(space,
                 interp_ufuncs.find_dtype_for_scalar(space, w_other,
                                                     self.get_dtype()),
                 w_other
@@ -243,10 +220,17 @@
 @unwrap_spec(offset=int)
 def descr_new_array(space, w_subtype, w_shape, w_dtype=None, w_buffer=None,
                     offset=0, w_strides=None, w_order=None):
+    if (offset != 0 or not space.is_w(w_strides, space.w_None) or
+        not space.is_w(w_order, space.w_None) or
+        not space.is_w(w_buffer, space.w_None)):
+        raise OperationError(space.w_NotImplementedError,
+                             space.wrap("unsupported param"))
     dtype = space.interp_w(interp_dtype.W_Dtype,
           space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype))
     shape = _find_shape(space, w_shape)
-    return W_NDimArray(shape, dtype)
+    if not shape:
+        return W_NDimArray.new_scalar(space, dtype)
+    return W_NDimArray.from_shape(shape, dtype)
 
 W_NDimArray.typedef = TypeDef(
     "ndarray",
@@ -319,7 +303,7 @@
             w_dtype = interp_ufuncs.find_dtype_for_scalar(space, w_object)
         dtype = space.interp_w(interp_dtype.W_Dtype,
           space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype))
-        return scalar_w(space, dtype, w_object)
+        return W_NDimArray.new_scalar(space, dtype, w_object)
     if w_order is None or space.is_w(w_order, space.w_None):
         order = 'C'
     else:
@@ -347,7 +331,7 @@
             dtype = interp_dtype.get_dtype_cache(space).w_float64dtype
     if ndmin > len(shape):
         shape = [1] * (ndmin - len(shape)) + shape
-    arr = W_NDimArray(shape, dtype, order=order)
+    arr = W_NDimArray.from_shape(shape, dtype, order=order)
     arr_iter = arr.create_iter(arr.get_shape())
     for w_elem in elems_w:
         arr_iter.setitem(dtype.coerce(space, w_elem))
@@ -361,8 +345,8 @@
     )
     shape = _find_shape(space, w_shape)
     if not shape:
-        return scalar_w(space, dtype, space.wrap(0))
-    return space.wrap(W_NDimArray(shape, dtype=dtype, order=order))
+        return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
+    return space.wrap(W_NDimArray.from_shape(shape, dtype=dtype, order=order))
 
 @unwrap_spec(order=str)
 def ones(space, w_shape, w_dtype=None, order='C'):
@@ -371,8 +355,8 @@
     )
     shape = _find_shape(space, w_shape)
     if not shape:
-        return scalar_w(space, dtype, space.wrap(0))
-    arr = W_NDimArray(shape, dtype=dtype, order=order)
+        return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
+    arr = W_NDimArray.from_shape(shape, dtype=dtype, order=order)
     one = dtype.box(1)
     arr.fill(one)
     return space.wrap(arr)
diff --git a/pypy/module/micronumpy/interp_support.py b/pypy/module/micronumpy/interp_support.py
--- a/pypy/module/micronumpy/interp_support.py
+++ b/pypy/module/micronumpy/interp_support.py
@@ -5,12 +5,11 @@
 from pypy.objspace.std.strutil import strip_spaces
 from pypy.rlib import jit
 from pypy.rlib.rarithmetic import maxint
+from pypy.module.micronumpy.base import W_NDimArray
 
 FLOAT_SIZE = rffi.sizeof(lltype.Float)
 
 def _fromstring_text(space, s, count, sep, length, dtype):
-    from pypy.module.micronumpy.interp_numarray import W_NDimArray
-
     sep_stripped = strip_spaces(sep)
     skip_bad_vals = len(sep_stripped) == 0
 
@@ -52,7 +51,7 @@
         raise OperationError(space.w_ValueError, space.wrap(
             "string is smaller than requested size"))
 
-    a = W_NDimArray([num_items], dtype=dtype)
+    a = W_NDimArray.from_shape([num_items], dtype=dtype)
     ai = a.create_iter()
     for val in items:
         a.dtype.setitem(a, ai.offset, val)
@@ -61,8 +60,6 @@
     return space.wrap(a)
 
 def _fromstring_bin(space, s, count, length, dtype):
-    from pypy.module.micronumpy.interp_numarray import W_NDimArray
-    
     itemsize = dtype.itemtype.get_element_size()
     assert itemsize >= 0
     if count == -1:
@@ -75,7 +72,7 @@
         raise OperationError(space.w_ValueError, space.wrap(
             "string is smaller than requested size"))
         
-    a = W_NDimArray([count], dtype=dtype)
+    a = W_NDimArray.from_shape([count], dtype=dtype)
     fromstring_loop(a, dtype, itemsize, s)
     return space.wrap(a)
 
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -8,7 +8,7 @@
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.module.micronumpy.interp_support import unwrap_axis_arg
 from pypy.module.micronumpy.strides import shape_agreement
-from pypy.module.micronumpy.support import convert_to_array
+from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
 
 def done_if_true(dtype, val):
     return dtype.itemtype.bool(val)
@@ -38,7 +38,6 @@
         return self.identity
 
     def descr_call(self, space, __args__):
-        from interp_numarray import W_NDimArray
         args_w, kwds_w = __args__.unpack()
         # it occurs to me that we don't support any datatypes that
         # require casting, change it later when we do
@@ -141,8 +140,6 @@
 
     def reduce(self, space, w_obj, multidim, promote_to_largest, w_axis,
                keepdims=False, out=None):
-        from pypy.module.micronumpy.interp_numarray import W_NDimArray
-
         if self.argcount != 2:
             raise OperationError(space.w_ValueError, space.wrap("reduce only "
                 "supported for binary functions"))
@@ -196,7 +193,7 @@
                 #        "mismatched  dtypes"))
                 return self.do_axis_reduce(obj, out.find_dtype(), axis, out)
             else:
-                result = W_NDimArray(shape, dtype)
+                result = W_NDimArray.from_shape(shape, dtype)
                 return self.do_axis_reduce(obj, dtype, axis, result)
         if out:
             if len(out.get_shape())>0:
@@ -232,7 +229,6 @@
         self.bool_result = bool_result
 
     def call(self, space, args_w):
-        from pypy.module.micronumpy.interp_numarray import W_NDimArray
         if len(args_w)<2:
             [w_obj] = args_w
             out = None
@@ -289,8 +285,6 @@
 
     @jit.unroll_safe
     def call(self, space, args_w):
-        from pypy.module.micronumpy.interp_numarray import W_NDimArray
-        
         if len(args_w) > 2:
             [w_lhs, w_rhs, w_out] = args_w
         else:
diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py
--- a/pypy/module/micronumpy/iter.py
+++ b/pypy/module/micronumpy/iter.py
@@ -44,6 +44,7 @@
 
 from pypy.module.micronumpy.strides import enumerate_chunks,\
      calculate_slice_strides
+from pypy.module.micronumpy.base import W_NDimArray
 from pypy.rlib import jit
 
 # structures to describe slicing
@@ -56,13 +57,12 @@
         self.name = name
 
     def apply(self, arr):
-        from pypy.module.micronumpy.interp_numarray import slice_w
-
         arr = arr.get_concrete()
         ofs, subdtype = arr.dtype.fields[self.name]
         # strides backstrides are identical, ofs only changes start
-        return slice_w(arr.start + ofs, arr.strides[:], arr.backstrides[:],
-                       arr.shape[:], arr, subdtype)
+        return W_NDimArray.new_slice(arr.start + ofs, arr.strides[:],
+                                     arr.backstrides[:],
+                                     arr.shape[:], arr, subdtype)
 
 class Chunks(BaseChunk):
     def __init__(self, l):
@@ -80,14 +80,12 @@
         return shape[:] + old_shape[s:]
 
     def apply(self, arr):
-        from pypy.module.micronumpy.interp_numarray import slice_w
-
         shape = self.extend_shape(arr.shape)
         r = calculate_slice_strides(arr.shape, arr.start, arr.strides,
                                     arr.backstrides, self.l)
         _, start, strides, backstrides = r
-        return slice_w(start, strides[:], backstrides[:],
-                           shape[:], arr)
+        return W_NDimArray.new_slice(start, strides[:], backstrides[:],
+                                     shape[:], arr)
 
 
 class Chunk(BaseChunk):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -3,11 +3,11 @@
 signatures
 """
 
-from pypy.module.micronumpy.support import create_array
+from pypy.module.micronumpy.base import W_NDimArray
 
 def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     if out is None:
-        out = create_array(shape, res_dtype)
+        out = W_NDimArray.from_shape(shape, res_dtype)
     left_iter = w_lhs.create_iter(shape)
     right_iter = w_rhs.create_iter(shape)
     out_iter = out.create_iter(shape)
@@ -23,7 +23,7 @@
 
 def call1(shape, func, name , calc_dtype, res_dtype, w_obj, out):
     if out is None:
-        out = create_array(shape, res_dtype)
+        out = W_NDimArray.from_shape(shape, res_dtype)
     obj_iter = w_obj.create_iter(shape)
     out_iter = out.create_iter(shape)
     while not out_iter.done():
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,5 +1,6 @@
 from pypy.rlib import jit
 from pypy.interpreter.error import OperationError
+from pypy.module.micronumpy.base import W_NDimArray
 
 @jit.look_inside_iff(lambda chunks: jit.isconstant(len(chunks)))
 def enumerate_chunks(chunks):
@@ -48,7 +49,6 @@
     return rstrides, rbackstrides
 
 def is_single_elem(space, w_elem, is_rec_type):
-    from pypy.module.micronumpy.interp_numarray import W_NDimArray
     if (is_rec_type and space.isinstance_w(w_elem, space.w_tuple)):
         return True
     if (space.isinstance_w(w_elem, space.w_tuple) or
@@ -105,8 +105,6 @@
     return coords, step, lngth
 
 def shape_agreement(space, shape1, w_arr2):
-    from pypy.module.micronumpy.interp_numarray import W_NDimArray
-
     if w_arr2 is None:
         return shape1
     assert isinstance(w_arr2, W_NDimArray)
@@ -171,12 +169,6 @@
         neg_dim = -1
         batch = space.listview(w_iterable)
         new_size = 1
-        if len(batch) < 1:
-            if old_size == 1:
-                # Scalars can have an empty size.
-                new_size = 1
-            else:
-                new_size = 0
         new_shape = []
         i = 0
         for elem in batch:
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -8,24 +8,18 @@
         i *= x
     return i
 
-def convert_to_array(space, w_obj):
-    from pypy.module.micronumpy.interp_numarray import W_NDimArray, array,\
-         scalar_w
-    from pypy.module.micronumpy import interp_ufuncs
-    
-    if isinstance(w_obj, W_NDimArray):
-        return w_obj
-    elif space.issequence_w(w_obj):
-        # Convert to array.
-        return array(space, w_obj, w_order=None)
-    else:
-        # If it's a scalar
-        dtype = interp_ufuncs.find_dtype_for_scalar(space, w_obj)
-        return scalar_w(space, dtype, w_obj)
-
-def create_array(shape, dtype):
-    """ Convinient shortcut to avoid circular imports
-    """
-    from pypy.module.micronumpy.interp_numarray import W_NDimArray
-    
-    return W_NDimArray(shape, dtype)
+def calc_strides(shape, dtype, order):
+    strides = []
+    backstrides = []
+    s = 1
+    shape_rev = shape[:]
+    if order == 'C':
+        shape_rev.reverse()
+    for sh in shape_rev:
+        strides.append(s * dtype.get_size())
+        backstrides.append(s * (sh - 1) * dtype.get_size())
+        s *= sh
+    if order == 'C':
+        strides.reverse()
+        backstrides.reverse()
+    return strides, backstrides
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -26,7 +26,7 @@
     return Chunks(chunks).apply(a).implementation
 
 def create_array(*args, **kwargs):
-    return W_NDimArray(*args, **kwargs).implementation
+    return W_NDimArray.from_shape(*args, **kwargs).implementation
 
 class TestNumArrayDirect(object):
     def newslice(self, *args):
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -4,12 +4,11 @@
 
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy import interp_boxes
-from pypy.module.micronumpy.support import create_array
 from pypy.objspace.std.floatobject import float2string
 from pypy.rlib import rfloat, clibffi
 from pypy.rlib.rawstorage import (alloc_raw_storage, raw_storage_setitem,
                                   raw_storage_getitem)
-from pypy.rlib.objectmodel import specialize, we_are_translated
+from pypy.rlib.objectmodel import specialize
 from pypy.rlib.rarithmetic import widen, byteswap
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.rlib.rstruct.runpack import runpack


More information about the pypy-commit mailing list