simple symbolic math in Python
Kragen Sitaker
kragen at dnaco.net
Sun Jan 7 21:31:55 EST 2001
I posted most of this to kragen-hacks
<kragen-hacks-subscribe at kragen.dnaco.net> late last millennium.
I'm running into a problem: I can overload unary minus, but I can't
overload e.g. math.sin and math.cos, because they're not methods. It
would be really nice to have a way to do that. I can create objects
that act just like dictionaries or files, including working with almost
all of the built-in functions, but it doesn't look like I can create
objects that look just like numbers.
I'm thinking about implementing promises (aka deferred references) as
in E <http://www.erights.org/> as well; I'm still reading the E
documentation.
I hope folks enjoy this.
# Kragen Sitaker, 2000-12-30.
# Some simple symbolic math.
# Allows you to add, subtract, etc., simple formulas -- using the ordinary
# Python operations. Implements all of the standard arithmetic operations.
# Once you have a formula, you can print it out as a string, evaluate it
# for particular values of its variables, call 'simplified()' to get a
# simplified version, check to see if it's exactly identical to some other
# formula, or (in some cases --- if you've used only +, -, and *, and
# fixed exponents) take its symbolic derivative. See the test()
# routine at the end for details. Symbolic derivatives will generally
# need to be simplified() to taste very good.
# Not intended for serious use; it's just a quick hack I wrote this afternoon.
# Some things it might be fun to add:
# - a compile() method that returns a Python code object that would give you
# faster evaluation
# - a continued-fraction output mode a la HAKMEM
# - symbolic derivatives that cover more operations
# - better simplifications ( ((x + x) + (2 * x)) should simplify to (3 * x) )
# - unary operations: negation, transcendentals
# - better printing: a + b + c + d should print as a + b + c + d, not
# as (((a + b) + c) + d)
# - other symbolic manipulations
import math
# things inherit from Formula to get the glue that turns Python
# expressions into representations of expressions
class Formula:
    """Base class that turns Python arithmetic into expression trees.

    Subclasses are expected to provide ``eval(env)``; this class only
    supplies the operator hooks.  Each arithmetic operator builds a new
    ``Binop``/``Unop`` node instead of computing a value.
    """

    # Numeric conversions evaluate the formula against an empty environment.
    def __complex__(self): return complex(self.eval({}))
    def __int__(self): return int(self.eval({}))
    def __long__(self): return long(self.eval({}))
    def __float__(self): return float(self.eval({}))

    # Unary operators.
    def __pos__(self): return self # positive
    def __neg__(self): return Unop('-', self)

    # Binary operators; the __r*__ variants handle "number <op> formula".
    def __add__(self, other): return Binop('+', self, other)
    def __radd__(self, other): return Binop('+', other, self)
    def __sub__(self, other): return Binop('-', self, other)
    def __rsub__(self, other): return Binop('-', other, self)
    def __mul__(self, other): return Binop('*', self, other)
    def __rmul__(self, other): return Binop('*', other, self)
    def __div__(self, other): return Binop('/', self, other)
    def __rdiv__(self, other): return Binop('/', other, self)
    # The `/` operator is dispatched to __truediv__/__rtruediv__ whenever
    # true division is in effect, so alias them to the classic hooks.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__
    def __pow__(self, other): return Binop('**', self, other)
    def __rpow__(self, other): return Binop('**', other, self)

    # one out of place: syntactic sugar for 'eval'
    # this lets me say f.where(x = 2) instead of f.eval({'x':2})
    def where(self, **vars): return self.eval(vars)
def constant(expr):
    """Tell whether *expr* is a ``Constant`` node."""
    is_literal = isinstance(expr, Constant)
    return is_literal
# simplify an addition expression by dropping zeroes
def simplify_add(a, b):
    """Simplify the sum ``a + b`` of two already-simplified formulas.

    A zero operand is dropped, and adding a negation is rewritten as a
    subtraction; otherwise the plain sum node is returned.
    """
    if a.identical(mkf(0)):
        return b
    if b.identical(mkf(0)):
        return a
    # This reaches into Unop's private fields rather than asking the
    # object politely -- OO purity is knowingly given up here.
    if isinstance(b, Unop) and b._op == '-':
        difference = a - b._arg
        return difference.simplified()
    return a + b
# simplify a multiplication expression by dropping ones and converting
# 0 * anything to 0
def simplify_multiply(a, b):
    """Simplify the product ``a * b`` of two already-simplified formulas.

    A zero factor collapses the product to 0, a factor of 1 is dropped,
    and a factor of -1 becomes a (simplified) negation of the other side.
    """
    if a.identical(mkf(0)) or b.identical(mkf(0)):
        return mkf(0)
    if a.identical(mkf(1)):
        return b
    if b.identical(mkf(1)):
        return a
    if a.identical(mkf(-1)):
        negated = -b
        return negated.simplified()
    if b.identical(mkf(-1)):
        negated = -a
        return negated.simplified()
    return a * b
def simplify_subtract(a, b):
    """Simplify the difference ``a - b`` of two already-simplified formulas.

    Subtracting zero yields ``a``; subtracting a negation is rewritten as
    an addition; anything else is returned as a plain difference node.
    """
    if b.identical(mkf(0)):
        return a
    if isinstance(b, Unop) and b._op == '-':
        total = a + b._arg
        return total.simplified()
    return a - b
def simplify_power(a, b):
    """Simplify ``a ** b``: exponent 0 gives 1, exponent 1 gives ``a``."""
    if b.identical(mkf(0)):
        return mkf(1)
    if b.identical(mkf(1)):
        return a
    return a ** b
class DerivativeError(Exception):
    """Can't differentiate: the expression is beyond this module's rules."""


def power_derivative(base, exp, var):
    """Return the derivative of ``base ** exp`` with respect to ``var``.

    Only exponents whose own derivative simplifies to 0 are supported;
    the result is ``exp * base ** (exp - 1) * d(base)/d(var)`` and is
    not simplified.

    Raises:
        DerivativeError: if the exponent varies with ``var``.  The
            exception args are (message, exp, var).
    """
    exponent_slope = exp.derivative(var).simplified()
    if not exponent_slope.identical(mkf(0)):
        raise DerivativeError("too dumb for varying exponent", exp, var)
    return exp * base ** (exp - 1) * base.derivative(var)
# Binary operation class
class Binop(Formula):
    """A binary operation node: an operator symbol applied to two formulas."""

    # How to compute each operator on already-evaluated operand values.
    opmap = { '+': lambda a, b: a + b,
              '*': lambda a, b: a * b,
              '-': lambda a, b: a - b,
              '/': lambda a, b: a / b,
              '**': lambda a, b: a ** b }

    def __init__(self, op, value1, value2):
        """Store the operator symbol and coerce both operands with mkf()."""
        self.op = op
        self.values = mkf(value1), mkf(value2)

    def __str__(self):
        return "(%s %s %s)" % (self.values[0], self.op, self.values[1])

    def eval(self, env):
        """Evaluate both operands in *env* and combine them with the operator."""
        return self.opmap[self.op](self.values[0].eval(env),
                                   self.values[1].eval(env))

    # the partial derivative with respect to some variable 'var'
    derivmap = { '+': lambda a, b, var: a.derivative(var) + b.derivative(var),
                 '-': lambda a, b, var: a.derivative(var) - b.derivative(var),
                 '*': lambda a, b, var: (a * b.derivative(var) +
                                         b * a.derivative(var)),
                 # quotient rule: (a'b - ab') / b**2.  The node is built
                 # directly so this entry does not depend on which
                 # division hook the operands happen to define.
                 '/': lambda a, b, var: Binop('/',
                                              (a.derivative(var) * b -
                                               a * b.derivative(var)),
                                              b ** 2),
                 '**': power_derivative }

    def derivative(self, var):
        """Return the (unsimplified) partial derivative with respect to *var*."""
        return self.derivmap[self.op](self.values[0], self.values[1], var)

    # very basic simplifications
    simplifymap = { '+': simplify_add,
                    '*': simplify_multiply,
                    '-': simplify_subtract,
                    '**': simplify_power }

    def simplified(self):
        """Return a simplified copy of this node.

        Both operands are simplified first.  Two constant operands are
        folded into a single constant; otherwise the per-operator rule
        from ``simplifymap`` is applied when one exists.
        """
        values = self.values[0].simplified(), self.values[1].simplified()
        if constant(values[0]) and constant(values[1]):
            # this is kinda gross
            return mkf(Binop(self.op, values[0], values[1]).eval({}))
        elif self.op in self.simplifymap:
            return self.simplifymap[self.op](values[0], values[1])
        else:
            return Binop(self.op, values[0], values[1])

    def identical(self, other):
        """True when *other* is a Binop with the same operator and operands."""
        return (isinstance(other, Binop) and other.op == self.op and
                other.values[0].identical(self.values[0]) and
                other.values[1].identical(self.values[1]))
class Unop(Formula):
    """A unary operation node: negation, sine or cosine of one formula."""

    # How to compute each operator on an already-evaluated argument.
    opmap = { '-': lambda x: -x,
              'sin': lambda x: math.sin(x),
              'cos': lambda x: math.cos(x) }

    def __init__(self, op, arg):
        """Store the operator name and coerce the argument with mkf()."""
        self._op = op
        self._arg = mkf(arg)

    def __str__(self):
        return "%s(%s)" % (self._op, self._arg)

    def eval(self, env):
        """Evaluate the argument in *env* and apply the operator to it."""
        return self.opmap[self._op](self._arg.eval(env))

    # note that each of these entries implicitly contains the chain rule;
    # that's bad and should be fixed.
    derivmap = { '-': lambda x, var: -x.derivative(var),
                 'sin': lambda x, var: x.derivative(var) * cos(x),
                 'cos': lambda x, var: -x.derivative(var) * sin(x) }

    def derivative(self, var):
        """Return the (unsimplified) derivative with respect to *var*."""
        return self.derivmap[self._op](self._arg, var)

    def identical(self, other):
        """True when *other* is a Unop with the same operator and argument."""
        # The operator must match as well as the argument; without that
        # check -x would compare identical to cos(x).
        return (isinstance(other, Unop) and other._op == self._op and
                self._arg.identical(other._arg))

    def simplified(self):
        """Return a simplified copy of this node.

        A constant argument is folded into a single constant, and a
        double negation collapses to the inner argument.
        """
        simplearg = self._arg.simplified()
        if constant(simplearg):
            return mkf(Unop(self._op, simplearg).eval({}))
        if (self._op == '-' and isinstance(simplearg, Unop) and
            simplearg._op == '-'):
            return simplearg._arg.simplified()
        else: return Unop(self._op, simplearg)
def cos(f):
    """Build the symbolic cosine of formula *f*."""
    node = Unop('cos', f)
    return node
def sin(f):
    """Build the symbolic sine of formula *f*."""
    node = Unop('sin', f)
    return node
class Variable(Formula):
    """Leaf node standing for a named variable."""

    def __init__(self, name):
        self._name = name

    def __str__(self):
        return self._name

    def eval(self, environment):
        """Look this variable's name up in the *environment* mapping."""
        return environment[self._name]

    def simplified(self):
        """A bare variable is already as simple as it gets."""
        return self

    def identical(self, other):
        """True when *other* is a Variable carrying the same name."""
        if not isinstance(other, Variable):
            return False
        return other._name == self._name

    def derivative(self, var):
        """Return constant 1 if *var* names this variable, else constant 0."""
        if self._name == var._name:
            slope = 1
        else:
            slope = 0
        return mkf(slope)
class Constant(Formula):
    """Leaf node wrapping a fixed value."""

    def __init__(self, value):
        self._value = value

    def __str__(self):
        return str(self._value)

    def eval(self, env):
        """Return the wrapped value; *env* is not consulted."""
        return self._value

    def simplified(self):
        """A constant is already as simple as it gets."""
        return self

    def derivative(self, var):
        """The derivative of a constant is the constant 0."""
        return mkf(0)

    def identical(self, other):
        """True when *other* is a Constant whose value compares equal."""
        if not isinstance(other, Constant):
            return False
        return other._value == self._value
# make formula
def mkf(value):
    """Coerce *value* into a Formula.

    Numbers become ``Constant`` nodes, strings become ``Variable`` nodes
    named by the string, and existing ``Formula`` objects are returned
    unchanged.

    Raises:
        TypeError: for any other kind of value; the exception args are
            ("Can't make formula from", value).
    """
    # Imported locally so this function carries its own dependency.
    import numbers

    # isinstance() rather than an exact type() match, so subclasses of
    # the numeric and string types are accepted too.
    if isinstance(value, numbers.Number):
        return Constant(value)
    if isinstance(value, str):
        return Variable(value)
    if isinstance(value, Formula):
        return value
    raise TypeError("Can't make formula from", value)
# syntactic sugar so you can say vars.foo instead of mkf('foo') or
# Variable('foo')
class Vars:
    """Attribute-access factory for variables: ``vars.foo`` names 'foo'."""

    def __getattr__(self, name):
        # Only reached for attributes not found the normal way; every
        # such name is turned into a fresh Variable.
        return Variable(name)


vars = Vars()
def test():
    """Self-test for the module; an AssertionError marks the first failure."""
    # --- constants and plain arithmetic --------------------------------
    assert mkf(2365).eval({}) == 2365
    one = mkf(1)
    assert str(one) == '1'
    assert one.eval({}) == 1
    assert isinstance(one + one, Formula)
    assert (one + one).eval({}) == 2
    assert str(one + one) == '(1 + 1)'

    # --- variables -------------------------------------------------------
    x = vars.x
    assert isinstance(x, Variable)
    assert x.eval({'x': 37}) == 37
    assert (one + x).eval({'x': 108}) == 109
    assert str(one + x) == '(1 + x)'

    # evaluating an unbound variable must raise KeyError
    unbound_raised = 0
    try:
        x.eval({})
    except KeyError:
        unbound_raised = 1
    assert unbound_raised

    # --- mixing plain numbers with formulas, and numeric conversions ------
    assert (1 + one).eval({}) == 2
    assert (2 * mkf(3)).eval({}) == 6
    assert (mkf(2) * 3).eval({}) == 6
    assert (14 - one).eval({}) == 13
    assert (one - 14).eval({}) == -13
    assert int(one) == 1
    seven = 14 / mkf(2)
    assert isinstance(seven, Formula)
    assert seven.eval({}) == 7
    assert float(seven) == 7.0
    assert int(+one) == 1

    # something that is neither number, string nor Formula is rejected
    rejected = 0
    try:
        mkf(test)
    except TypeError:
        rejected = 1
    assert rejected

    # --- powers, division, string form -------------------------------------
    two_to_the_x = 2 ** x
    assert str(two_to_the_x) == '(2 ** x)'
    assert two_to_the_x.eval({'x': 20}) == 1048576
    assert two_to_the_x.where(x=20) == 1048576
    assert (x ** 2).eval({'x': 13}) == 169
    formula = (x + 1) / ((x * x) - +two_to_the_x)
    assert str(formula) == '((x + 1) / ((x * x) - (2 ** x)))', str(formula)
    assert (x / 1).eval({'x': 36}) == 36
    # long(1) is the same value the literal 1L denotes
    assert long(one) == long(1)
    assert complex(one) == 1+0j
    i = mkf(1j)
    assert complex(i) == 1j

    # --- unary minus and trig ------------------------------------------------
    y = vars.y
    minusx = -x
    assert minusx.where(x=5) == -5
    assert (-minusx).simplified().identical(x)
    assert str(minusx) == '-(x)'
    assert not minusx.identical(x)
    assert minusx.identical(-x)
    assert not minusx.identical(-y)
    cosx = cos(x)
    assert cosx.where(x=0) == 1
    assert cosx.where(x=1) != 1

    # --- derivatives -----------------------------------------------------------
    assert x.derivative(x).simplified().identical(mkf(1))
    assert x.derivative(y).simplified().identical(mkf(0))
    assert (x * y).derivative(x).simplified().identical(y)
    assert (x + y).derivative(x).simplified().identical(mkf(1))
    assert (x - y).derivative(x).simplified().identical(mkf(1))
    assert two_to_the_x.simplified().identical(two_to_the_x)
    assert (2 * x).derivative(x).simplified().identical(mkf(2))
    assert (x ** 0).simplified().identical(mkf(1))
    assert (x ** 1).simplified().identical(x)
    assert (x ** 2).derivative(x).simplified().identical(2 * x)
    assert minusx.derivative(x).simplified().eval({}) == -1
    assert cosx.derivative(x).where(x=0) == 0

    # --- simplification ----------------------------------------------------------
    assert (-1 * x).simplified().identical(-x)
    assert (x * -1).simplified().identical(-x)
    assert (mkf(2) * 3 * 4).simplified().identical(mkf(24))
    assert (-mkf(1)).simplified().identical(mkf(-1))
    assert (-(1 * cos(x))).simplified().identical(-cos(x))
    sinx = sin(x)
    assert (-(-(mkf(1)) * sinx)).simplified().identical(sinx)
    one = mkf(1)
    assert (-one * (one * (-one * x))).simplified().identical(x)
    assert (1 - -x).simplified().identical(1 + x)
    assert (1 - - - -x).simplified().identical(1 + x)


test()
--
<kragen at pobox.com> Kragen Sitaker <http://www.pobox.com/~kragen/>
Perilous to all of us are the devices of an art deeper than we possess
ourselves.
-- Gandalf the White [J.R.R. Tolkien, "The Two Towers", Bk 3, Ch. XI]
More information about the Python-list
mailing list