simple symbolic math in Python

Kragen Sitaker kragen at dnaco.net
Sun Jan 7 21:31:55 EST 2001


I posted most of this to kragen-hacks
<kragen-hacks-subscribe at kragen.dnaco.net> late last millennium.

I'm running into a problem: I can overload unary minus, but I can't
overload e.g. math.sin and math.cos, because they're not methods.  It
would be really nice to have a way to do that.  I can create objects
that act just like dictionaries or files, including working with almost
all of the built-in functions, but it doesn't look like I can create
objects that look just like numbers.

I'm thinking about implementing promises (aka deferred references) as
in E <http://www.erights.org/> as well; I'm still reading the E
documentation.

I hope folks enjoy this.

# Kragen Sitaker, 2000-12-30.
# Some simple symbolic math.
# Allows you to add, subtract, etc., simple formulas -- using the ordinary
# Python operations.  Implements all of the standard arithmetic operations.
# Once you have a formula, you can print it out as a string, evaluate it
# for particular values of its variables, call 'simplified()' to get a
# simplified version, check to see if it's exactly identical to some other
# formula, or (in some cases --- if you've used only +, -, *, negation,
# sin, cos, and fixed exponents) take its symbolic derivative.  See the test()
# routine at the end for details. Symbolic derivatives will generally
# need to be simplified() to taste very good.

# Not intended for serious use; it's just a quick hack I wrote this afternoon.

# Some things it might be fun to add:
# - a compile() method that returns a Python code object that would give you
#   faster evaluation
# - a continued-fraction output mode a la HAKMEM
# - symbolic derivatives that cover more operations
# - better simplifications ( ((x + x) + (2 * x)) should simplify to (4 * x) )
# - more unary operations: transcendentals beyond sin and cos
# - better printing: a + b + c + d should print as a + b + c + d, not
#   as (((a + b) + c) + d)
# - other symbolic manipulations

import math

# things inherit from Formula to get the glue that turns Python
# expressions into representations of expressions
class Formula:
    """Base class whose arithmetic operators build formulas, not numbers.

    Subclasses must provide eval(env); the concrete node classes in this
    module also provide __str__, derivative(var), simplified() and
    identical(other).  Applying +, -, *, /, ** or unary minus to a Formula
    returns a new Binop/Unop node describing the operation.
    """
    # Numeric conversions evaluate with an empty environment, so they only
    # succeed for formulas containing no Variables (an unbound Variable
    # raises KeyError from its eval()).
    def __complex__(self): return complex(self.eval({}))
    def __int__(self): return int(self.eval({}))
    def __long__(self): return long(self.eval({}))
    def __float__(self): return float(self.eval({}))
    def __pos__(self): return self  # positive
    def __neg__(self): return Unop('-', self)
    def __add__(self, other): return Binop('+', self, other)
    def __radd__(self, other): return Binop('+', other, self)
    def __sub__(self, other): return Binop('-', self, other)
    def __rsub__(self, other): return Binop('-', other, self)
    def __mul__(self, other): return Binop('*', self, other)
    def __rmul__(self, other): return Binop('*', other, self)
    def __div__(self, other): return Binop('/', self, other)
    def __rdiv__(self, other): return Binop('/', other, self)
    # Python 3, and Python 2 under 'from __future__ import division', send
    # '/' to these names instead of __div__/__rdiv__; without them
    # 'x / 2' would raise TypeError there.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__
    def __pow__(self, other): return Binop('**', self, other)
    def __rpow__(self, other): return Binop('**', other, self)

    # one out of place: syntactic sugar for 'eval'
    # this lets me say f.where(x = 2) instead of f.eval({'x':2})
    def where(self, **vars): return self.eval(vars)

def constant(expr):
    """Report whether expr is a Constant node (a literal number)."""
    is_literal = isinstance(expr, Constant)
    return is_literal

# simplify an addition expression by dropping zeroes
def simplify_add(a, b):
    """Simplify the sum a + b of two formulas.

    Binop.simplified() calls this with operands it has already simplified.
    0 + b gives b, a + 0 gives a, and a + (-c) is rewritten as the
    simplified form of a - c.  Anything else becomes the plain Binop a + b.
    """
    zero = mkf(0)
    if a.identical(zero):
        return b
    if b.identical(zero):
        return a
    # peeking at Unop's private fields: not very OO, but it keeps this simple
    b_is_negation = isinstance(b, Unop) and b._op == '-'
    if b_is_negation:
        return (a - b._arg).simplified()
    return a + b

# simplify a multiplication expression by dropping ones and converting
# 0 * anything to 0
def simplify_multiply(a, b):
    """Simplify the product a * b of two formulas.

    A zero factor gives 0, a factor of 1 is dropped, and a factor of -1
    turns the product into the simplified negation of the other factor.
    Anything else becomes the plain Binop a * b.
    """
    zero, one, minus_one = mkf(0), mkf(1), mkf(-1)
    if a.identical(zero) or b.identical(zero):
        return zero
    if a.identical(one):
        return b
    if b.identical(one):
        return a
    if a.identical(minus_one):
        return (-b).simplified()
    if b.identical(minus_one):
        return (-a).simplified()
    return a * b

def simplify_subtract(a, b):
    """Simplify the difference a - b of two formulas.

    a - 0 gives a, and a - (-c) is rewritten as the simplified form of
    a + c.  Anything else becomes the plain Binop a - b.
    """
    if b.identical(mkf(0)):
        return a
    if not (isinstance(b, Unop) and b._op == '-'):
        return a - b
    # subtracting a negation is adding its argument
    return (a + b._arg).simplified()

def simplify_power(a, b):
    """Simplify a ** b: an exponent of 1 gives a, an exponent of 0 gives 1.

    Any other exponent leaves the plain Binop a ** b.
    """
    if b.identical(mkf(1)):
        return a
    if b.identical(mkf(0)):
        return mkf(1)
    return a ** b

class DerivativeError(Exception):
    """Raised by power_derivative() when the exponent depends on the
    differentiation variable, which this module cannot handle."""
    pass

def power_derivative(base, exp, var):
    """Return d(base ** exp)/d(var) as an (unsimplified) formula.

    Only constant exponents are supported: if the exponent's derivative
    with respect to var does not simplify to 0, DerivativeError is raised
    with the arguments (message, exp, var).  Otherwise the power rule and
    chain rule give exp * base ** (exp - 1) * d(base)/d(var).
    """
    if not exp.derivative(var).simplified().identical(mkf(0)):
        # A class instance, not a string: raising strings is a TypeError
        # on Python 2.6+ and unsupported on Python 3.
        raise DerivativeError("too dumb for varying exponent", exp, var)
    return exp * base ** (exp - 1) * base.derivative(var)

# Binary operation class
class Binop(Formula):
    """A binary operation: an operator symbol applied to two sub-formulas.

    op is one of the keys of opmap ('+', '-', '*', '/', '**').  values is
    a pair of Formulas; raw numbers and strings are coerced with mkf().
    """
    # numeric meaning of each operator symbol, used by eval()
    opmap = { '+': lambda a, b: a + b,
              '*': lambda a, b: a * b,
              '-': lambda a, b: a - b,
              '/': lambda a, b: a / b,
              '**': lambda a, b: a ** b }
    def __init__(self, op, value1, value2):
        self.op = op
        self.values = mkf(value1), mkf(value2)
    def __str__(self):
        # always fully parenthesized, e.g. "(x + 1)"
        return "(%s %s %s)" % (self.values[0], self.op, self.values[1])
    def eval(self, env):
        """Evaluate both operands in env, then apply the operator."""
        return self.opmap[self.op](self.values[0].eval(env),
                                   self.values[1].eval(env))
    # the partial derivative with respect to some variable 'var'.
    # There is no entry for '/', so differentiating a quotient raises
    # KeyError from the lookup in derivative().
    derivmap = { '+': lambda a, b, var: a.derivative(var) + b.derivative(var),
                 '-': lambda a, b, var: a.derivative(var) - b.derivative(var),
                 '*': lambda a, b, var: (a * b.derivative(var) +
                                         b * a.derivative(var)),
                 '**': power_derivative }
    def derivative(self, var):
        return self.derivmap[self.op](self.values[0], self.values[1], var)

    # very basic simplifications; '/' gets only constant folding
    simplifymap = { '+': simplify_add,
                    '*': simplify_multiply,
                    '-': simplify_subtract,
                    '**': simplify_power }
    def simplified(self):
        """Return a simplified equivalent of this formula.

        Both operands are simplified first.  If both are then Constants the
        operation is evaluated to a single Constant; otherwise the
        operator-specific rule from simplifymap is applied, if there is one.
        """
        values = self.values[0].simplified(), self.values[1].simplified()
        if constant(values[0]) and constant(values[1]):
            # this is kinda gross: build a throwaway Binop just to evaluate it
            return mkf(Binop(self.op, values[0], values[1]).eval({}))
        elif self.op in self.simplifymap:
            return self.simplifymap[self.op](values[0], values[1])
        else:
            return Binop(self.op, values[0], values[1])

    def identical(self, other):
        """True if other is a Binop with the same operator and identical
        operands in the same order (no commutativity is assumed)."""
        return (isinstance(other, Binop) and other.op == self.op and
                other.values[0].identical(self.values[0]) and
                other.values[1].identical(self.values[1]))

class Unop(Formula):
    """A unary operation ('-', 'sin' or 'cos') applied to one sub-formula."""
    # numeric meaning of each operator name, used by eval()
    opmap = { '-': lambda x: -x,
              'sin': lambda x: math.sin(x),
              'cos': lambda x: math.cos(x) }
    def __init__(self, op, arg):
        self._op = op
        self._arg = mkf(arg)
    def __str__(self):
        # function-call notation, e.g. "-(x)" or "sin(x)"
        return "%s(%s)" % (self._op, self._arg)
    def eval(self, env):
        return self.opmap[self._op](self._arg.eval(env))
    # note that each of these entries implicitly contains the chain rule;
    # that's bad and should be fixed.
    derivmap = { '-': lambda x, var: -x.derivative(var),
                 'sin': lambda x, var: x.derivative(var) * cos(x),
                 'cos': lambda x, var: -x.derivative(var) * sin(x) }
    def derivative(self, var):
        return self.derivmap[self._op](self._arg, var)
    def identical(self, other):
        """True if other applies the same unary operator to an identical
        argument."""
        # The operator has to match as well as the argument; comparing only
        # the argument made sin(x), cos(x) and -(x) all "identical".
        return (isinstance(other, Unop) and other._op == self._op and
                self._arg.identical(other._arg))
    def simplified(self):
        """Return a simplified equivalent: constant arguments are folded to
        a number and double negation -(-(y)) collapses to y."""
        simplearg = self._arg.simplified()
        if constant(simplearg):
            return mkf(Unop(self._op, simplearg).eval({}))
        if (self._op == '-' and isinstance(simplearg, Unop) and
            simplearg._op == '-'):
            return simplearg._arg.simplified()
        else: return Unop(self._op, simplearg)

def cos(f):
    """Build the symbolic cosine of f (any value mkf() accepts)."""
    node = Unop('cos', f)
    return node

def sin(f):
    """Build the symbolic sine of f (any value mkf() accepts)."""
    node = Unop('sin', f)
    return node

class Variable(Formula):
    """A named free variable whose value is looked up in the environment."""
    def __init__(self, name):
        self._name = name

    def __str__(self):
        return self._name

    def eval(self, environment):
        # an unbound name raises KeyError
        return environment[self._name]

    def simplified(self):
        return self

    def identical(self, other):
        return isinstance(other, Variable) and self._name == other._name

    def derivative(self, var):
        # d(self)/d(var) is 1 only when var is this same variable
        if self._name != var._name:
            return mkf(0)
        return mkf(1)
class Constant(Formula):
    """A literal number; it evaluates to itself in any environment."""
    def __init__(self, value):
        self._value = value

    def __str__(self):
        return str(self._value)

    def eval(self, env):
        return self._value

    def simplified(self):
        return self

    def identical(self, other):
        # numeric equality, so e.g. Constant(2) and Constant(2.0) match
        return isinstance(other, Constant) and self._value == other._value

    def derivative(self, var):
        # a constant never varies
        return mkf(0)

# make formula
# Types that mkf() turns into Constants and Variables.  The numeric types are
# found by probing sample values (type(2 ** 100) is long on Python 2 and int
# on Python 3), and the string test falls back from basestring to str, so no
# Python-2-only syntax such as the 1L literal is needed.
_NUMBER_TYPES = (type(1), type(2 ** 100), type(1.5), type(1j))
try:
    _STRING_TYPES = (basestring,)
except NameError:
    _STRING_TYPES = (str,)

def mkf(value):
    """Coerce value into a Formula.

    Numbers (int, long, float, complex, including subclasses such as bool)
    become Constants; strings become the Variable with that name; Formulas
    are returned unchanged.  Anything else raises TypeError.
    """
    # isinstance() rather than an exact type() match, so that subclasses
    # (bool, and unicode on Python 2) are accepted too
    if isinstance(value, _NUMBER_TYPES):
        return Constant(value)
    elif isinstance(value, _STRING_TYPES):
        return Variable(value)
    elif isinstance(value, Formula):
        return value
    else:
        raise TypeError("Can't make formula from", value)

# syntactic sugar so you can say vars.foo instead of mkf('foo') or
# Variable('foo')
class Vars:
    """Attribute-access factory: vars.foo returns Variable('foo').

    Double-underscore (dunder) names raise AttributeError instead.  Protocol
    probes such as hasattr(vars, '__len__'), copy and pickle would otherwise
    get a Variable back and try to call it.  The same applies to str()/repr()
    on this old-style class under Python 2.
    """
    def __getattr__(self, name):
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return Variable(name)
vars = Vars()

def test():
    """Self-test for the module: construction, printing, evaluation,
    simplification and differentiation of formulas.

    Each check is a bare assert, so a failure raises AssertionError.
    """
    # constants and basic evaluation
    assert mkf(2365).eval({}) == 2365
    one = mkf(1)
    assert str(one) == '1'
    assert one.eval({}) == 1
    assert isinstance(one + one, Formula)
    assert (one + one).eval({}) == 2
    assert str(one + one) == '(1 + 1)'
    # variables; evaluating one that is not bound raises KeyError
    x = vars.x
    assert isinstance(x, Variable)
    assert x.eval({'x': 37}) == 37
    assert (one + x).eval({'x': 108}) == 109
    assert str(one + x) == '(1 + x)'
    got_error = 0
    try:
        x.eval({})
    except KeyError:
        got_error = 1
    assert got_error
    # plain numbers on either side of an operator, and numeric conversions
    assert (1 + one).eval({}) == 2
    assert (2 * mkf(3)).eval({}) == 6
    assert (mkf(2) * 3).eval({}) == 6
    assert (14 - one).eval({}) == 13
    assert (one - 14).eval({}) == -13
    assert int(one) == 1
    seven = (14 / mkf(2))
    assert isinstance(seven, Formula)
    assert seven.eval({}) == 7
    assert float(seven) == 7.0
    assert int(+one) == 1
    # mkf() rejects values that are not numbers, strings or formulas
    got_error = 0
    try:
        z = mkf(test)
    except TypeError:
        got_error = 1
    assert got_error
    # exponentiation and printing of nested formulas
    two_to_the_x = (2 ** x)
    assert str(two_to_the_x) == '(2 ** x)'
    assert two_to_the_x.eval({'x': 20}) == 1048576
    assert two_to_the_x.where(x=20) == 1048576
    assert (x ** 2).eval({'x': 13}) == 169
    formula = (x + 1)/((x * x) - +two_to_the_x)
    assert str(formula) == '((x + 1) / ((x * x) - (2 ** x)))', str(formula)
    assert (x / 1).eval({'x': 36}) == 36
    # Python 2 only: the long() builtin and the 1L literal
    assert long(one) == 1L
    assert complex(one) == 1+0j
    i = mkf(1j)
    assert complex(i) == 1j

    # negation
    y = vars.y
    minusx = -x
    assert minusx.where(x = 5) == -5
    assert (-minusx).simplified().identical(x)
    assert str(minusx) == '-(x)'
    assert not minusx.identical(x)
    assert minusx.identical(-x)
    assert not minusx.identical(-y)

    # transcendental functions
    cosx = cos(x)
    assert cosx.where(x = 0) == 1
    assert cosx.where(x = 1) != 1

    # derivatives, checked after simplification
    assert x.derivative(x).simplified().identical(mkf(1))
    assert x.derivative(y).simplified().identical(mkf(0))
    assert (x * y).derivative(x).simplified().identical(y)
    assert (x + y).derivative(x).simplified().identical(mkf(1))
    assert (x - y).derivative(x).simplified().identical(mkf(1))
    assert two_to_the_x.simplified().identical(two_to_the_x)
    assert (2 * x).derivative(x).simplified().identical(mkf(2))

    # power rules, multiplication by +/-1, constant folding, double negation
    assert (x ** 0).simplified().identical(mkf(1))
    assert (x ** 1).simplified().identical(x)
    assert (x ** 2).derivative(x).simplified().identical(2 * x)
    assert minusx.derivative(x).simplified().eval({}) == -1
    assert cosx.derivative(x).where(x = 0) == 0
    assert (-1 * x).simplified().identical(-x)
    assert (x * -1).simplified().identical(-x)
    assert (mkf(2) * 3 * 4).simplified().identical(mkf(24))
    assert (-mkf(1)).simplified().identical(mkf(-1))
    assert (-(1 * cos(x))).simplified().identical(-cos(x))
    sinx = sin(x)
    assert (-(-(mkf(1)) * sinx)).simplified().identical(sinx)
    # rebinds 'one' to a fresh Constant with the same value
    one = mkf(1)
    assert (-one * (one * (-one * x))).simplified().identical(x)
    assert (1 - -x).simplified().identical(1 + x)
    assert (1 - - - -x).simplified().identical(1 + x)
    
# Run the self-test only when this file is executed as a script, so that
# importing it as a module does not run (and possibly fail on) test().
if __name__ == '__main__':
    test()
-- 
<kragen at pobox.com>       Kragen Sitaker     <http://www.pobox.com/~kragen/>
Perilous to all of us are the devices of an art deeper than we possess
ourselves.
       -- Gandalf the White [J.R.R. Tolkien, "The Two Towers", Bk 3, Ch. XI]



More information about the Python-list mailing list