[Numpy-discussion] force column vector

Stefan van der Walt stefan at sun.ac.za
Wed Feb 7 09:03:50 EST 2007


On Wed, Feb 07, 2007 at 10:35:14AM +0000, Christian wrote:
> Hi,
> 
> when creating an ndarray from a list, how can I force the result to be
> 2d *and* a column vector? If I pass a nested list, the shape should not
> be modified, and if I pass a simple list, it should be converted to a 2d
> column vector. I can only think of a solution using 'if' clauses, but I
> suppose there is a more elegant way.

One way is to subclass ndarray so that 1-D input is reshaped into a column when the array is constructed:

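import numpy as np

class ColumnVectorArray(np.ndarray):
    """ndarray subclass that stores 1-D input as an (n, 1) column vector."""
    def __new__(cls, data):
        # Convert the input to an array and view it as this subclass.
        data = np.asarray(data).view(cls)
        # A flat (1-D) input becomes a single column;
        # nested (2-D or higher) input keeps its shape.
        if data.ndim == 1:
            data = data.reshape(-1, 1)
        return data

x = ColumnVectorArray([[1, 2], [3, 4], [5, 6]])
print('x =')
print(x)
print()

y = ColumnVectorArray([1, 2, 3])
print('y =')
print(y)
print()

# Shapes (3, 2) and (3, 1) broadcast against each other.
print('x+y =')
print(x + y)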

which yields:


x =
[[1 2]
 [3 4]
 [5 6]]

y =
[[1]
 [2]
 [3]]

x+y =
[[2 3]
 [5 6]
 [8 9]]
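If a subclass is more machinery than you need, a plain function can produce the same shapes without an explicit 'if'. This is a minimal sketch, not part of NumPy; the helper name as_column is hypothetical, and it assumes the input converts cleanly with np.asarray. It relies on .T reversing the axes: on a 1-D array .T does nothing, so np.atleast_2d turns it into shape (1, n), and the outer .T makes that (n, 1). For input that is already 2-D (or higher), the two transposes cancel and the shape is unchanged.

```python
import numpy as np

def as_column(data):
    """Hypothetical helper: return data with at least 2 dimensions.

    1-D input becomes an (n, 1) column vector; input that is already
    2-D or higher keeps its shape.
    """
    a = np.asarray(data)
    # a.T is a no-op for 1-D input, so atleast_2d gives shape (1, n)
    # and the outer .T flips it to (n, 1). For N-d input (N >= 2),
    # atleast_2d leaves a.T alone and the two transposes cancel.
    return np.atleast_2d(a.T).T

x = as_column([[1, 2], [3, 4], [5, 6]])
y = as_column([1, 2, 3])
print(x.shape)  # (3, 2)
print(y.shape)  # (3, 1)
print(x + y)    # same result as the x+y output above
```

The trade-off against the subclass is that the result is a plain ndarray, so the column-vector convention is applied once at conversion time rather than being carried by the type.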


Cheers
Stéfan
