[PYTHON MATRIX-SIG] Array.py

Hinsen Konrad hinsenk@ere.umontreal.ca
Thu, 12 Oct 1995 22:05:22 -0400


# J-style array class

# Arrays are represented by a scalar or a list, possibly containing
# other lists in case of higher-rank arrays. Array rank is limited
# only by the user's patience.

# Send comments to Konrad Hinsen <hinsenk@ere.umontreal.ca>


import types, math, string, regexp, copy


######################################################################
# Error type

# The module signals its errors by raising this string object together
# with a message, as in `raise ArrayError, 'Inconsistent shapes'`.
ArrayError = 'ArrayError'


######################################################################
# Various functions that do the real work. Classes follow.

# Construct string representation of array

def _output(data, dimension, maxlen):
    s = ''
    if dimension == 0:
	s = s + string.rjust(data,maxlen)
    elif dimension == 1:
	for e in data:
	    s = s + string.rjust(e,maxlen) + ' '
    else:
	separator = (dimension-1)*'\n'
	for e in data:
	    s = s + _output(e,dimension-1,maxlen) + separator
    i = len(s)-1
    while i > 0 and s[i] == '\n':
	i = i-1
    return s[:i+1]


# Find the shape of an array and check for consistency

def _shape(data):
    if type(data) == types.ListType:
	if data and type(data[0]) == types.ListType:
	    shapes = map(lambda x:_shape(x),data)
	    for i in range(1,len(shapes)):
		if shapes[i] != shapes[0]:
		    raise ArrayError, 'Inconsistent shapes'
	    shape = [len(data)]
	    shape = shape + shapes[0]
	    return shape
	else:
	    return [len(data)]
    else:
	return []


# Copy the data structure of an array

def _copy(data, dimension):
    if (dimension ==  0):
	return data
    else:
	c = copy.copy(data)
	for i in range(len(c)):
	    c[i] = _copy(c[i], dimension-1)
	return c


# Construct a one-dimensional list of all array elements

def __ravel(data):
    if type(data) == types.ListType:
	if len(data) and type(data[0]) == types.ListType:
	    return reduce(lambda a,b: a+b,
			  map(lambda x: __ravel(x), data))
	else:
	    return data
    else:
	return [data]

def _ravel(array):
    # Return a rank-1 Array holding all elements of `array`.  Its length
    # is the product of the entries of array._shape (1 if there are none).
    count = 1
    for extent in array._shape:
        count = count * extent
    return Array(__ravel(array._data), [count])


# Reshape an array

def _reshape(array, shape):
    # Rearrange the elements of `array` into the dimensions listed in
    # the Array `shape`.  With an empty shape the first element is
    # returned.  If more elements are needed than are available, the
    # element list is repeated; surplus elements are dropped.
    flat = _ravel(array)
    dims = shape._data
    if len(dims) == 0:
        return take(flat, 0)
    items = _copy(flat._data, len(flat._shape))
    total = reduce(lambda a,b: a*b, dims)
    available = len(items)
    if total > available:
        # Number of repetitions, rounded up
        repeats = (total+available-1)/available
        items = (repeats*items)[:total]
    elif total < available:
        items = items[:total]
    # Starting with the last dimension, replace each run of `width`
    # consecutive entries by a single list holding them.
    groups = total
    for axis in range(len(dims)-1, 0, -1):
        width = dims[axis]
        groups = groups/width
        for start in range(groups):
            items[start:start+width] = [items[start:start+width]]
    return Array(items, dims)


# Map a function on the first dimensions of an array

def _extract(a, index, dimension):
    if len(a[1]) < dimension:
	return a
    else:
	return (a[0][index], a[1][1:], a[2])

def _map(function, arglist, max_frame, scalar_flag):
    # Apply `function` over the frame described by max_frame.
    #   arglist     -- one (data, frame, cell shape) tuple per argument
    #   max_frame   -- the frame dimensions still to be looped over
    #   scalar_flag -- true: pass the raw data to the function and use
    #                  its result directly; false: wrap each argument
    #                  in an Array and take the _data of the result
    # Returns the result as raw (possibly nested list) data.
    if len(max_frame) == 0:
        if scalar_flag:
            values = []
            for arg in arglist:
                values.append(arg[0])
            return apply(function, tuple(values))
        cells = []
        for arg in arglist:
            cells.append(Array(arg[0], arg[2]))
        return apply(function, tuple(cells))._data
    depth = len(max_frame)
    inner_frame = max_frame[1:]
    result = []
    for index in range(max_frame[0]):
        # Arguments whose frame is shorter than the current depth are
        # passed down unchanged by _extract.
        selected = []
        for arg in arglist:
            selected.append(_extract(arg, index, depth))
        result.append(_map(function, selected, inner_frame, scalar_flag))
    return result


# Reduce an array with a given binary function

def _reduce(function, array):
    # Operator body: combine the items along the first dimension of an
    # array with a binary function.  `function` is a one-element list
    # holding the binary function, `array` a one-element sequence holding
    # the argument.  A rank-0 array is returned unchanged; an array with
    # no items yields the function's neutral element reshaped to the
    # shape of one item.
    operation = function[0]
    operand = array[0]
    if len(operand._shape) == 0:
        return operand
    count = operand._shape[0]
    if count == 0:
        return reshape(operation._neutral, operand._shape[1:])
    result = Array(operand._data[0], operand._shape[1:])
    for index in range(1, count):
        item = Array(operand._data[index], operand._shape[1:])
        result = operation(result, item)
    return result

def _cumulative(function, array):
    # Operator body: running combination along the first dimension.
    # Item i of the result is the binary function applied successively
    # to items 0..i of the argument.  `function` and `array` are passed
    # as one-element sequences.  Rank-0 arrays and arrays without items
    # are returned unchanged.
    operation = function[0]
    operand = array[0]
    if len(operand._shape) == 0 or operand._shape[0] == 0:
        return operand
    running = operand._data[0]
    partials = [running]
    for index in range(1, operand._shape[0]):
        item = Array(operand._data[index], operand._shape[1:])
        running = operation(running, item)._data
        partials.append(running)
    return Array(partials, operand._shape)


# Find the higher of two ranks

def _maxrank(a,b):
    if a == None or b == None:
	return None
    else:
	return max(a,b)

######################################################################
# Array class definition

class Array:

    # An array is a scalar or a (possibly nested) list kept in _data,
    # together with the list of its dimension lengths in _shape.

    def __init__(self, scalar_or_list, shape = None):
        # The data is stored as given, without copying.  If no shape is
        # supplied it is determined from the data by _shape, which also
        # checks it for consistency.
        self._data = scalar_or_list
        if shape == None:
            self._shape = _shape(self._data)
        else:
            self._shape = shape

    def __copy__(self):
        # New Array with its own list structure and its own shape list.
        return Array(_copy(self._data, len(self._shape)),
                     copy.copy(self._shape))

    def __str__(self):
        # Convert every element to a string, find the longest one and
        # lay the strings out in columns of that width.
        s = tostr(self)
        maxstrlen = maximum.over(_strlen(_ravel(s)))._data
        return _output(s._data,len(s._shape),maxstrlen)

    __repr__ = __str__

    def __len__(self):
        # Length of the outermost list; 1 if the data is not a list.
        if type(self._data) == types.ListType:
            return len(self._data)
        else:
            return 1

    def __getitem__(self, index):
        return take(self, index)

    def __getslice__(self, i, j):
        # A slice is the selection of the indices i, i+1, ..., j-1.
        return take(self, range(i,j))

    # Arithmetic is delegated to the array functions of this module.
    # The reflected forms of the non-commutative operations swap the
    # operands.

    def __add__(self, other):
        return sum(self, other)
    __radd__ = __add__

    def __sub__(self, other):
        return difference(self, other)
    def __rsub__(self, other):
        return difference(other, self)

    def __mul__(self, other):
        return product(self, other)
    __rmul__ = __mul__

    def __div__(self, other):
        return quotient(self, other)
    def __rdiv__(self, other):
        return quotient(other, self)

    def __pow__(self,other):
        return power(self, other)
    def __rpow__(self,other):
        return power(other, self)

    def __neg__(self):
        # Negation is subtraction from zero.
        return 0-self

    def writeToFile(self, filename):
        # Write the string form of the array, followed by a newline, to
        # the named file, replacing any previous contents.  The file is
        # closed explicitly, also when writing fails.
        output = open(filename, 'w')
        try:
            output.write(str(self)+'\n')
        finally:
            output.close()


# Check for arrayness

def isArray(x):
    # An object counts as an array if it has a _shape attribute.
    has_shape = hasattr(x, '_shape')
    return has_shape


# Read array from file

# Patterns used by _convertEntry to recognise numbers in a text field.
# Integer: an optional minus sign followed by one or more digits.
_int_pattern = regexp.compile('-?[0-9]+')
# Float: optional minus sign, optional digits, a dot, optional digits and
# an optional exponent part.
# NOTE(review): assuming the parentheses group, the trailing '*' lets the
# exponent part repeat, a lone '.' is accepted, and a number without a dot
# (e.g. '1e5') is not -- confirm whether this is intended.
_float_pattern = regexp.compile('-?[0-9]*\\.[0-9]*([eE][+-]?[0-9]+)*')

def _match(pattern,string):
    r = pattern.match(string)
    for i in r:
	if i == (0, len(string)):
	    return 1
    return 0
    
def _convertEntry(s):
    # Turn a text field into a number if it looks like one: an integer
    # if the integer pattern matches all of it, else a float if the
    # float pattern does.  Anything else is returned as the string it is.
    if _match(_int_pattern, s):
        return string.atoi(s)
    if _match(_float_pattern, s):
        return string.atof(s)
    return s

def readArray(filename):
    # Read an array from a text file.
    # Every non-empty line is one row of whitespace-separated entries,
    # converted by _convertEntry.  Lines starting with '#' are skipped.
    # A run of n empty lines starts a new sub-array n levels above the
    # rows.  Enclosing lists with a single element are removed before
    # the Array is built.
    current = root = []         # list receiving rows / outermost list
    levels = []                 # levels[k]: list currently open k+1 levels above the rows
    pending = 0                 # empty lines seen since the last row
    source = open(filename)
    line = source.readline()
    while line:
        if line[0] != '#':
            entries = map(_convertEntry, string.split(line))
            if len(entries) == 0:
                pending = pending + 1
            else:
                if pending:
                    # Add outer levels until there are enough of them
                    while pending > len(levels):
                        root = [root]
                        levels.append(root)
                    # Open a new sub-array at the requested level and
                    # create the chain of lists below it
                    current = []
                    levels[pending-1].append(current)
                    for depth in range(pending-2, -1, -1):
                        current.append([])
                        levels[depth] = current
                        current = current[0]
                current.append(entries)
                pending = 0
        line = source.readline()
    source.close()
    while type(root) == types.ListType and len(root) == 1:
        root = root[0]
    return Array(root)


######################################################################
# Array function class

class ArrayFunction:

    def __init__(self, function, ranks, intrinsic_ranks=None):
	self._function = function
	if isArray(ranks):
	    self._ranks = ranks._data
	elif type(ranks) == types.ListType:
	    self._ranks = ranks
	else:
	    self._ranks = [ranks]
	if intrinsic_ranks == None:
	    self._intrinsic_ranks = self._ranks
	else:
	    self._intrinsic_ranks = intrinsic_ranks
	if len(self._ranks)  == 1:
	    self._ranks = len(self._intrinsic_ranks)*self._ranks
	    

    def __call__(self, *args):
	if len(self._ranks) != len(args):
	    raise ArrayError, 'Wrong number of arguments for an array function'
	arglist = []
	framelist = []
	shapelist = []
	for i in range(len(args)):
	    if isArray(args[i]):
		arglist.append(args[i])
	    else:
		arglist.append(Array(args[i]))
	    shape = arglist[i]._shape
	    rank = self._ranks[i]
	    intrinsic_rank = self._intrinsic_ranks[i]
	    if rank == None:
		cell = 0
	    elif rank < 0:
		cell = min(-rank,len(shape))
	    else:
		cell = max(0,len(shape)-rank)
	    if intrinsic_rank != None:
		cell = max(cell,len(shape)-intrinsic_rank)
	    framelist.append(shape[:cell])
	    shapelist.append(shape[cell:])
	max_frame = []
	for frame in framelist:
	    if len(frame) > len(max_frame):
		max_frame = frame
	for i in range(len(framelist)):
	    if framelist[i] != max_frame[len(max_frame)-len(framelist[i]):]:
		raise ArrayError, 'Incompatible arguments'
	scalar_function = reduce(lambda a,b:_maxrank(a,b),
				 self._intrinsic_ranks) == 0
	return Array(_map(self._function, map(lambda a,b,c: (a._data,b,c),
					      arglist, framelist, shapelist),
			  max_frame, scalar_function))

    def __getitem__(self, ranks):
	return ArrayFunction(self._function,ranks,self._intrinsic_ranks)


class BinaryArrayFunction(ArrayFunction):

    # An array function of two arguments with a neutral element.  Two
    # derived functions are attached to each instance:
    #   over       -- built on _reduce
    #   cumulative -- built on _cumulative
    # Both take the whole argument as a single cell (rank None).

    def __init__(self, function, neutral_element, ranks, intrinsic_ranks=None):
        ArrayFunction.__init__(self, function, ranks, intrinsic_ranks)
        self._neutral = neutral_element
        reducer = ArrayOperator(_reduce, [self])
        scanner = ArrayOperator(_cumulative, [self])
        self.over = ArrayFunction(reducer, [None])
        self.cumulative = ArrayFunction(scanner, [None])

    def __getitem__(self, ranks):
        # Same function and neutral element with other ranks
        return BinaryArrayFunction(self._function, self._neutral,
                                   ranks, self._intrinsic_ranks)


class ArrayOperator:

    # Binds an operator to the list of functions it works with.  The
    # resulting object is callable with the remaining arguments.

    def __init__(self, operator, function_list):
        self._operator = operator
        self._functions = function_list

    def __call__(self, *args):
        # The operator receives exactly two arguments: the function list
        # and the tuple of call arguments.
        return self._operator(self._functions, args)


######################################################################
# Array functions

# Functions for internal use
# Length of each element; Array.__str__ uses it to find the field width.
_strlen = ArrayFunction(len, [0])

# Structural functions
# In the rank lists below, None passes the whole argument to the function
# as a single cell, 0 applies the function to every scalar element and 1
# passes cells of rank 1.
shape =   ArrayFunction(lambda a: Array(a._shape,[len(a._shape)]), [None])
reshape = ArrayFunction(_reshape, [None, 1])
ravel =   ArrayFunction(_ravel, [None])
# take selects along the first dimension, once for each scalar index.
take =    ArrayFunction(lambda a,i: Array(a._data[i._data], a._shape[1:]),
                        [None, 0])

# Elementwise binary functions
# The names with an underscore work on pairs of scalars.  The public
# names wrap them as binary array functions with a neutral element, which
# _reduce returns (reshaped) when there are no items to combine.
# max and min are still the builtins here; the module redefines these
# names only further down.
# NOTE(review): 0 as neutral element for maximum, minimum and difference,
# and 1 for quotient, power and the comparisons, is not an identity of
# those operations in general -- confirm what an empty reduction should give.
_sum =        ArrayFunction(lambda a,b: a+b, [0, 0])
_difference = ArrayFunction(lambda a,b: a-b, [0, 0])
_product =    ArrayFunction(lambda a,b: a*b, [0, 0])
_quotient =   ArrayFunction(lambda a,b: a/b, [0, 0])
_power =      ArrayFunction(pow, [0, 0])
_max =        ArrayFunction(max, [0, 0])
_min =        ArrayFunction(min, [0, 0])
_smaller =    ArrayFunction(lambda a,b: a<b, [0, 0])
_greater =    ArrayFunction(lambda a,b: a>b, [0, 0])
_equal =      ArrayFunction(lambda a,b: a==b, [0, 0])
sum  =        BinaryArrayFunction(_sum, 0, [None, None])
difference  = BinaryArrayFunction(_difference, 0, [None, None])
product  =    BinaryArrayFunction(_product, 1, [None, None])
quotient  =   BinaryArrayFunction(_quotient, 1, [None, None])
power =       BinaryArrayFunction(_power, 1, [None, None])
maximum =     BinaryArrayFunction(_max, 0, [None, None])
minimum =     BinaryArrayFunction(_min, 0, [None, None])
smaller =     BinaryArrayFunction(_smaller, 1, [None, None])
greater =     BinaryArrayFunction(_greater, 1, [None, None])
equal =       BinaryArrayFunction(_equal, 1, [None, None])

# Scalar functions of one variable
# Each applies the named function to every scalar element of its argument.
tostr = ArrayFunction(str, [0])
sqrt  = ArrayFunction(math.sqrt, [0])
exp   = ArrayFunction(math.exp, [0])
log   = ArrayFunction(math.log, [0])
log10 = ArrayFunction(math.log10, [0])
sin   = ArrayFunction(math.sin, [0])
cos   = ArrayFunction(math.cos, [0])
tan   = ArrayFunction(math.tan, [0])
asin  = ArrayFunction(math.asin, [0])
acos  = ArrayFunction(math.acos, [0])
atan  = ArrayFunction(math.atan, [0])
sinh  = ArrayFunction(math.sinh, [0])
cosh  = ArrayFunction(math.cosh, [0])
tanh  = ArrayFunction(math.tanh, [0])
floor = ArrayFunction(math.floor, [0])
ceil  = ArrayFunction(math.ceil, [0])


# Nasty hack to make max and min safe to use.
# Without this, they would return an array as most users
# would expect, but it would not be the correct answer.
# I know I shouldn't do this, but it seems the lesser of
# two evils.

# Keep the builtin functions reachable under other names; this has to
# happen before the names max and min are rebound by the definitions below.
builtin_max = max
builtin_min = min

# Module-level max: passes all its arguments on to the builtin max.
def max(*args):
    return apply(builtin_max,args)
# Module-level min: passes all its arguments on to the builtin min.
def min(*args):
    return apply(builtin_min,args)

# test data
# This statement runs whenever the module is executed or imported and
# binds the module-level name x to an Array built from range(10).
x = Array(range(10))

=================
MATRIX-SIG  - SIG on Matrix Math for Python

send messages to: matrix-sig@python.org
administrivia to: matrix-sig-request@python.org
=================