Design By Contract in python
Arthur Gordon
arthur.gordon at thales-esecurity.com
Tue Jul 24 06:23:10 EDT 2001
An implementation of Design by Contract in Python. It uses
docstrings to hold the pre- and postconditions, as described in the paper
http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html
Enjoy
-Arthur
--------------- cut here -----------------------------
"""FILENAME:
DesignByContract.py
DESCRIPTION:
An implementation of Design by Contract in Python 2.1
based on a paper by R. Plosch (Johannes Kepler Universitaet Linz),
which can be found at:
http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html
AUTHOR:
Arthur Gordon <arthur.gordon at thales-esecurity.com>
HISTORY:
18-Jul-01: Added support for old and param namespace.
17-Jul-01: Added PrintDbCTrace - which prints out the stack
up to the DbC classes, no further.
12-Jul-01: First attempt.
"""
import types
import sys
import traceback
import copy
# Smallest sys.hexversion this module accepts.  sys.hexversion packs the
# version as 0xMMmmppRS, so Python 2.0 starts at 0x02000000.  (The previous
# threshold, 0x200000, was below every released interpreter's value and so
# the guard could never trigger.)
DBC_MIN_HEXVERSION = 0x02000000

if sys.hexversion < DBC_MIN_HEXVERSION:
    # Single-argument print(...) behaves the same as a statement and as a
    # function, so this line is valid on every interpreter that reaches it.
    print('DesignByContract.py requires Python version 2.0 or greater.')
    # Non-zero status: refusing to run is a failure, not a success.
    sys.exit(1)
# Keywords that introduce a contract clause.  A docstring line is treated as
# a clause only when, after stripping whitespace, it starts with one of these.
KEYWORD_INVARIANT = 'invar:'    # condition checked by the invariant passes
KEYWORD_TYPE = 'type:'          # attribute-name:type-name pairs (class docstring)
KEYWORD_PARAMETER = 'param:'    # argument-name:type-name pairs (method docstring)
KEYWORD_REQUIRE = 'require:'    # precondition, evaluated before the call
KEYWORD_ENSURE = 'ensure:'      # postcondition, evaluated after the call

# Keyword lengths, computed once; they are used to slice the keyword off the
# front of a clause before the remainder is parsed or evaluated.
(LEN_KEYWORD_INVARIANT,
 LEN_KEYWORD_TYPE,
 LEN_KEYWORD_PARAMETER,
 LEN_KEYWORD_REQUIRE,
 LEN_KEYWORD_ENSURE) = map(len, (KEYWORD_INVARIANT,
                                 KEYWORD_TYPE,
                                 KEYWORD_PARAMETER,
                                 KEYWORD_REQUIRE,
                                 KEYWORD_ENSURE))

# Truth value returned by checkType for the 'AnyType' wildcard.
TRUE = (1 == 1)
""" Used for String to Type conversion.
asType ={ 'NoneType':types.NoneType, ...}
"""
def _buildTypeTable():
    """Return a dict mapping type names to type objects.

    Every attribute of the ``types`` module that is itself a type is entered
    under its own name.  The attributes are fetched with getattr() rather
    than by eval()-ing a constructed expression.

    Interpreters whose ``types`` module no longer provides the classic
    names for the built-in types (IntType, ListType, ...) get those names
    added as aliases, so contracts written with them keep resolving.  A name
    that ``types`` does provide is never overridden.
    """
    table = {}
    for name in dir(types):
        candidate = getattr(types, name)
        if isinstance(candidate, type):
            table[name] = candidate

    legacy_aliases = {
        'NoneType': type(None),
        'BooleanType': bool,
        'IntType': int,
        'LongType': int,
        'FloatType': float,
        'ComplexType': complex,
        'StringType': str,
        'UnicodeType': str,
        'ListType': list,
        'TupleType': tuple,
        'DictType': dict,
        'DictionaryType': dict,
    }
    for name, alias in legacy_aliases.items():
        table.setdefault(name, alias)   # keep the entry from ``types`` if any
    return table


# Name -> type lookup used by checkType.
asType = _buildTypeTable()
def checkType(param, typeAsString):
    """Tell whether ``param`` has the type named by ``typeAsString``.

    ``typeAsString`` is looked up in the ``asType`` table and compared with
    ``type(param)`` for exact equality.  The special name 'AnyType' accepts
    every value.
    """
    # The wildcard is tested first; ``or`` short-circuits, so 'AnyType' is
    # never looked up in the table.
    return typeAsString == 'AnyType' or type(param) == asType[typeAsString]
""" Violation handling
--------------------------------------------------------------------------
Examples for catching exceptions.
    try:
        TestStack()
    except DbCFormatError, msg:
        PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg), "FORMAT ERROR:")
    except DbCViolationError, msg:
        PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg))
"""
class DbCViolationError(Exception):
    """Exception type for a contract violation.

    Violation handlers may raise it with the offending contract line as
    the message.
    """
class DbCFormatError(Exception):
    """Exception type for a contract clause that is not correctly formatted.

    getParameters raises it with the offending text as the message.
    """
def FormatDbCTrace(trace, line, msg="VIOLATION:"):
    """Format a violation report whose stack stops before the DbC classes.

    trace -- stack entries as produced by traceback.extract_stack() or
             traceback.extract_tb(); each entry is indexed as
             (filename, line number, ...).
    line  -- text describing the violation; appended directly to ``msg``.
    msg   -- label placed in front of ``line`` in the banner.

    Returns a list of strings: a rule of 60 '=' characters, ``msg + line``,
    another rule, then the traceback.format_list() rendering of the frames
    that come before the first frame located in DesignByContract.py at or
    before line DBC_LAST_LINE.  If the trace holds no such frame (or is
    empty) the whole trace is rendered instead of raising IndexError.
    """
    FIELD_FILENAME = 0
    FIELD_LINENO = 1

    banner = '=' * 60
    retval = [banner, msg + line, banner]

    # Default to "keep everything"; shortened when a DbC frame is found.
    trace_slice = len(trace)
    for index, frame in enumerate(trace):
        # Frames of this file below DBC_LAST_LINE belong to the self-test,
        # not to the contract machinery, so they do not end the report.
        if (frame[FIELD_FILENAME].endswith('DesignByContract.py')
                and frame[FIELD_LINENO] <= DBC_LAST_LINE):
            trace_slice = index
            break

    retval.extend(traceback.format_list(trace[:trace_slice]))
    return retval
def PrintDbCTrace(trace, msg, str="VIOLATION:"):
    """Print the report built by FormatDbCTrace, one entry per print call.

    trace -- stack entries, passed through to FormatDbCTrace.
    msg   -- violation text, passed as FormatDbCTrace's ``line`` argument.
    str   -- banner label, passed as FormatDbCTrace's ``msg`` argument.
    """
    report = FormatDbCTrace(trace, msg, str)
    for entry in report:
        print(entry)
def LogDbCTrace(trace, line, fname='violation.txt'):
    """Write the report built by FormatDbCTrace to the file ``fname``.

    trace -- stack entries, passed through to FormatDbCTrace.
    line  -- violation text, passed through to FormatDbCTrace.
    fname -- path of the log file; it is opened in 'w+' mode, so any
             existing content is replaced.

    Every report entry ends up on its own line: entries that do not already
    end in a newline (the banner lines) get one appended, since writing the
    list as-is would run them together.  The file is closed even if
    formatting or writing fails.
    """
    report = FormatDbCTrace(trace, line)
    with open(fname, 'w+') as log:
        for entry in report:
            if not entry.endswith('\n'):
                entry += '\n'
            log.write(entry)
def contractViolationBehaviour(object, line):
    """Default violation handler: print the violation and the current stack.

    object -- the wrapped instance whose contract was violated (not used by
              this handler).
    line   -- the contract line that failed.

    This handler writes to the screen.  The other options are to raise an
    exception or to write to a log file:

        raise DbCViolationError(line)
        LogDbCTrace(traceback.extract_stack(), line)
    """
    current_stack = traceback.extract_stack()
    PrintDbCTrace(current_stack, line)
""" --------------------------------------------------------------------------
"""
def getParameters(orignal_line):
    """Parse 'name: Type, name: Type  # comment' into [[name, Type], ...].

    orignal_line -- the text that follows a 'param:' or 'type:' keyword.

    Everything from the first '#' onwards is discarded as a comment.  The
    remainder is split on ',' into entries and each entry on ':' into a
    name and a type name, both stripped of surrounding whitespace.

    Returns a real list of two-element lists, so callers can take its
    length and index it.

    Raises DbCFormatError (carrying the unmodified input) when an entry
    does not consist of exactly one name and one type name, or when either
    of the two is empty.
    """
    spec = orignal_line.split('#', 1)[0]    # drop a trailing comment

    pairs = []
    for entry in spec.split(','):
        fields = [field.strip() for field in entry.split(':')]
        # An empty name or type would only fail later, and less clearly,
        # when the type table is consulted - reject it here.
        if len(fields) != 2 or not fields[0] or not fields[1]:
            raise DbCFormatError(orignal_line)
        pairs.append(fields)
    return pairs
""" print "getParameters(' list: ListType, another:IntType #
Should work')"
print getParameters(' list: ListType, another:IntType #
Should work')
"""
# Attribute names that belong to a DbCClassWrapper itself.  Reads and writes
# of these names stay on the wrapper; every other attribute access is
# forwarded to the wrapped object.
DBC_CLASS_WRAPPER_PRIVATE_VARS = [
    '_wrapped_object',
    '_violationProc',
    '_docstr',
]
class DbCClassWrapper:
    """Proxy around an instance that checks the contracts in its docstrings.

    Attribute reads and writes are forwarded to the wrapped object.  Bound
    methods are handed out wrapped in a DbCMethodWrapper, which checks the
    method's own contract around each call.  The wrapped object's docstring
    is scanned for 'invar:' and 'type:' lines; these are checked on
    construction, after every forwarded attribute assignment and after
    every wrapped method call.
    """

    def __init__(self, object, violationProc=contractViolationBehaviour):
        """Wrap ``object`` and check its class invariants once.

        param: object:InstanceType

        violationProc is called as violationProc(wrapped_object, line)
        whenever a contract line fails.

        Note the leading underscores are required on the wrapper's own
        attribute names, as the wrapper shares its name space with the
        wrapped object.  The invariants are checked here because the
        wrapped object's own __init__ ran before any wrapper existed.
        """
        self._wrapped_object = object           # reference to wrapped instance
        self._violationProc = violationProc     # who you going to call?
        # Attach ourselves to the wrapped object so that method wrappers
        # can reach the violation handler and the invariant check.
        self._wrapped_object._dbc_wrapper = self
        doc = self._wrapped_object.__doc__
        if doc:
            # A real list (not a lazy map) so it can be scanned repeatedly.
            self._docstr = [text.strip() for text in doc.splitlines()]
        else:
            self._docstr = None
        self.checkClassInvariants()

    def __getattr__(self, name):
        """Forward attribute lookup to the wrapped object.

        param: name:StringType

        Only called when normal lookup on the wrapper fails.  Bound methods
        of the wrapped object are returned inside a DbCMethodWrapper.
        Raises AttributeError when ``name`` is one of the wrapper's private
        names but has not been set (e.g. on a half-constructed wrapper).
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:
            try:
                return self.__dict__[name]
            except KeyError:
                raise AttributeError(name)
        attribute = getattr(self._wrapped_object, name)
        if isinstance(attribute, types.MethodType):     # if it is a method...
            attribute = DbCClassWrapper.DbCMethodWrapper(
                self._wrapped_object, attribute)        # ...wrap it
        return attribute

    def __setattr__(self, name, val):
        """Store wrapper-private names locally, forward everything else.

        param: name:StringType, val:AnyType

        After a forwarded assignment the class invariants are re-checked.
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:
            self.__dict__[name] = val
        else:
            setattr(self._wrapped_object, name, val)
            self.checkClassInvariants()

    def checkClassInvariants(self):
        """Check the 'invar:' and 'type:' lines of the wrapped class docstring.

        Does nothing when the wrapped object has no docstring.  Invariant
        expressions are evaluated with the wrapped object's attributes as
        their namespace, so ``list`` in an invariant means the instance
        attribute of that name and not the builtin.  Each failing line is
        reported through the violation handler.
        """
        if not self._docstr:
            return
        # Copy the attributes: eval() inserts '__builtins__' into the dict
        # it is given, and that must not leak onto the wrapped object.
        namespace = dict(self._wrapped_object.__dict__)
        for line in self._docstr:
            if line.startswith(KEYWORD_INVARIANT):
                if not eval(line[LEN_KEYWORD_INVARIANT:], namespace):
                    self._violationProc(self._wrapped_object, line)
            elif line.startswith(KEYWORD_TYPE):
                for name, typeName in getParameters(line[LEN_KEYWORD_TYPE:]):
                    value = getattr(self._wrapped_object, name)
                    if not checkType(value, typeName):
                        self._violationProc(self._wrapped_object, line)

    class DbCMethodWrapper:
        """Callable that checks a method's contract around each call."""

        def __init__(self, object, method):
            """Remember the bound ``method`` and split its docstring into lines.

            param: object:InstanceType, method:MethodType

            ``object`` must already carry a ``_dbc_wrapper`` attribute; the
            violation handler is taken from it.
            """
            self.wrapped_object = object
            self.method = method
            if method.__doc__:
                self.docstr = [text.strip()
                               for text in method.__doc__.splitlines()]
            else:
                self.docstr = None
            self.violationProc = \
                self.wrapped_object._dbc_wrapper._violationProc
            # Namespace for the contract expressions of the current call.
            self.private_vars = {}

        def __call__(self, *args, **kwargs):
            """Call the wrapped method, checking its contract if it has one.

            param: args:AnyType

            Pre- and postconditions are only checked when the method has a
            docstring.  The class invariants are checked after every call
            either way.  Positional and keyword arguments are passed on to
            the method unchanged; its return value is returned.
            """
            if not self.docstr:
                retval = self.method(*args, **kwargs)
            else:
                # Shallow copy of the instance attributes.
                self.private_vars = self.getWrappedClassPrivateVars()
                self.checkBefore(args, kwargs)      # pre-conditions
                self.copyPrivateVarsToOld()         # snapshot for 'old.'
                retval = self.method(*args, **kwargs)
                self.private_vars.update(self.getWrappedClassPrivateVars())
                self.checkAfter()                   # post-conditions
                self.private_vars = {}
            # Always check the class invariants.
            self.wrapped_object._dbc_wrapper.checkClassInvariants()
            return retval

        def getWrappedClassPrivateVars(self):
            """Return a new dict with the wrapped object's attributes.

            The bookkeeping names 'docstr', '_dbc_wrapper' and
            '__builtins__' are left out.  Values are not copied.
            """
            skipped = ('docstr', '_dbc_wrapper', '__builtins__')
            retval = {}
            for key, value in self.wrapped_object.__dict__.items():
                if key not in skipped:
                    retval[key] = value
            return retval

        class DbCDictWrapper:
            """Empty attribute holder that allows the ``old.`` and
            ``param.`` notation in the DbC expressions."""
            pass

        def copyPrivateVarsToOld(self):
            """Deep-copy the current namespace into ``private_vars['old']``.

            'param' is not copied, and a '__builtins__' entry left behind
            by an earlier eval() is removed first.  Does nothing when the
            namespace is empty.
            """
            if not self.private_vars:
                return
            old = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
            self.private_vars.pop('__builtins__', None)
            for key, value in list(self.private_vars.items()):
                if key != 'param':
                    setattr(old, key, copy.deepcopy(value))
            self.private_vars['old'] = old

        def displayPrivateVars(self):
            """Print the contract namespace (debugging aid)."""
            rule = '-' * 60
            print(rule)
            for key in self.private_vars.keys():
                if key == '__builtins__':
                    continue
                print('[ %s ] = %s' % (key, self.private_vars[key]))
                if key in ('param', 'old'):
                    print('  %s' % (self.private_vars[key].__dict__,))
            print(rule)

        def checkBefore(self, *args):
            """Check the 'param:', 'require:' and 'invar:' lines before the call.

            param: args:AnyType
            require: self.docstr != None

            args[0] is the tuple of positional arguments of the call; an
            optional args[1] is the dict of its keyword arguments.  The
            names declared on a 'param:' line are matched against the
            positional arguments in order, starting from the first one for
            each such line; a name not covered positionally is looked up
            among the keyword arguments.  A declared name for which no
            value was passed is skipped (neither type-checked nor bound).
            Every value found is bound as ``param.<name>`` for use in the
            contract expressions.
            """
            positional = args[0]
            keywords = {}
            if len(args) > 1:
                keywords = args[1]

            param = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
            self.private_vars['param'] = param
            for line in self.docstr:
                if line.startswith(KEYWORD_PARAMETER):
                    declared = getParameters(line[LEN_KEYWORD_PARAMETER:])
                    for index, (name, typeName) in enumerate(declared):
                        if index < len(positional):
                            value = positional[index]
                        elif name in keywords:
                            value = keywords[name]
                        else:
                            continue        # not supplied by this call
                        if not checkType(value, typeName):
                            self.violationProc(self.wrapped_object, line)
                        setattr(param, name, value)
                elif line.startswith(KEYWORD_REQUIRE):
                    if not eval(line[LEN_KEYWORD_REQUIRE:],
                                self.private_vars):
                        self.violationProc(self.wrapped_object, line)
                elif line.startswith(KEYWORD_INVARIANT):
                    if not eval(line[LEN_KEYWORD_INVARIANT:],
                                self.private_vars):
                        self.violationProc(self.wrapped_object, line)

        def checkAfter(self):
            """Check the 'ensure:' and 'invar:' lines after the call.

            require: self.docstr != None

            A failing 'ensure:' line additionally dumps the contract
            namespace with displayPrivateVars() before it is reported.
            """
            for line in self.docstr:
                if line.startswith(KEYWORD_ENSURE):
                    if not eval(line[LEN_KEYWORD_ENSURE:],
                                self.private_vars):
                        self.displayPrivateVars()
                        self.violationProc(self.wrapped_object,
                                           ':"' + line + '"')
                elif line.startswith(KEYWORD_INVARIANT):
                    if not eval(line[LEN_KEYWORD_INVARIANT:],
                                self.private_vars):
                        self.violationProc(self.wrapped_object,
                                           ':"' + line + '"')
# Line number of this very statement.  FormatDbCTrace does not dump the
# stack any further once it reaches a DesignByContract.py frame at or before
# this line; frames of code placed after it in this file are still shown.
# The value is read from the running frame instead of being hard-coded, so
# it stays correct when lines are added or removed above.
DBC_LAST_LINE = sys._getframe().f_lineno
if __name__ == '__main__':
    # Self-test: wraps a small Stack class and checks that valid calls pass
    # and that contract violations are reported through the handler.
    import unittest

    class Stack:
        """ A simple test class
            invar: list != None         # Should work
            type: list: ListType        # Should work
        """

        def __init__(self):
            self.list = []

        def ppush(self, val):
            """ param: val:IntType                      # Should work
                require: param.val >1                   # Should work
                invar: list != None                     # Should work
                ensure: len(list) == len(old.list)+1    # Should work
            """
            self.list.append(val)

        def ppop(self):
            """ invar: list != None                     # Should work
                ensure: len(list) == len(old.list)-1    # Should work
            """
            return self.list.pop()

        # NOTE: __init__ does not create ``another``; it has to be set on
        # the instance before inc() is called.  None of the tests below
        # call inc().
        def inc(self, val):
            """ param: val:IntType
                require: another > 1                    # Should pass
                ensure: old.another > 100               # Should fail
                ensure: another == old.another + 1      # Should pass
                ensure: another == old.another + 10     # Should fail
            """
            self.another = self.another + 1

    class DesignByContractTests(unittest.TestCase):
        """Test methods use the 'check' prefix (see the loader below)."""

        def setUp(self):
            def onContractViolation(object, line):
                # Raise instead of print, so the tests can observe it.
                raise DbCViolationError(line)

            aStack = Stack()                    # call the constructor
            self.ws = DbCClassWrapper(aStack, onContractViolation)  # wrap it

        def checkPushValidValue(self):
            self.ws.ppush(10)

        def checkPushInValidType(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, 'asd')

        def checkPushInValidValue(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, -1)

        def checkPopValid(self):
            self.ws.ppush(10)
            # assertEqual rather than a bare assert: it is not stripped by
            # ``python -O`` and reports both values on failure.
            self.assertEqual(self.ws.ppop(), 10)

    loader = unittest.TestLoader()
    loader.testMethodPrefix = 'check'
    suite = loader.loadTestsFromTestCase(DesignByContractTests)
    retval = unittest.TextTestRunner().run(suite)
    # Report failures through the exit status as well as on the screen.
    if not retval.wasSuccessful():
        sys.exit(1)
More information about the Python-list
mailing list