Design By Contract in python
Steve Pike
spike at oztourer.nl
Wed Jul 25 10:11:06 EDT 2001
Interesting stuff, but I would enjoy it more if the source weren't
mangled by line wraps. I think I have reconstituted it OK, but perhaps
you could make the original source available for download?
Thanx,
Steve Pike
On 24 Jul 2001 03:23:10 -0700, arthur.gordon at thales-esecurity.com
(Arthur Gordon) wrote:
>An implementation of Design by Contract in Python. It uses the
>docstrings to hold the pre- and postconditions, as described in the paper
>http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html
>
>Enjoy
>
>-Arthur
>
>--------------- cut here -----------------------------
"""FILENAME:
    DesignByContract.py
DESCRIPTION:
    An implementation of Design by Contract in Python 2.1
    based on a paper by R. Plosch @ Johannes Kepler Universitaet Linz,
    which can be found at:
    http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html
AUTHOR:
    Arthur Gordon <arthur.gordon at thales-esecurity.com>
HISTORY:
    18-Jul-01: Added support for old and param namespace.
    17-Jul-01: Added PrintDbCTrace - which prints out the stack
               up to the DbC classes, no further.
    12-Jul-01: First attempt.
"""
import types
import sys
import traceback
import copy

# Minimum interpreter version, in sys.hexversion encoding (0xMMmmppRS):
# major, minor and micro version bytes followed by release level/serial.
# 0x02000000 is therefore Python 2.0.0; a value such as 0x200000 would mean
# version 0.32 and the check below could never trigger.
REQUIRED_HEXVERSION = 0x02000000

if sys.hexversion < REQUIRED_HEXVERSION:
    # Report on stderr and exit non-zero: this is a failure, not a success.
    sys.stderr.write('DesignByContract.py requires Python version 2.0 or greater.\n')
    sys.exit(1)
>
# Prefixes that introduce a contract clause on a (stripped) docstring line.
KEYWORD_INVARIANT = 'invar:'     # invariant, class or method level
KEYWORD_TYPE = 'type:'           # attribute types, class level
KEYWORD_PARAMETER = 'param:'     # argument types, method level
KEYWORD_REQUIRE = 'require:'     # pre-condition
KEYWORD_ENSURE = 'ensure:'       # post-condition

# Length of each prefix, used to slice it off before evaluating the clause.
(LEN_KEYWORD_INVARIANT,
 LEN_KEYWORD_TYPE,
 LEN_KEYWORD_PARAMETER,
 LEN_KEYWORD_REQUIRE,
 LEN_KEYWORD_ENSURE) = map(len, (KEYWORD_INVARIANT,
                                 KEYWORD_TYPE,
                                 KEYWORD_PARAMETER,
                                 KEYWORD_REQUIRE,
                                 KEYWORD_ENSURE))

# A true value that does not rely on the True builtin being available.
TRUE = (1 == 1)
>
# String-to-type table used by checkType, e.g.
#     asType = {'NoneType': types.NoneType, 'IntType': types.IntType, ...}
# It holds every public type object exported by the types module.
asType = {}
_TYPE_OF_TYPES = type(type(None))           # the type of type objects (types.TypeType)
for _name in dir(types):
    _candidate = getattr(types, _name)      # plain attribute lookup; no need for eval
    if not _name.startswith('_') and isinstance(_candidate, _TYPE_OF_TYPES):
        asType[_name] = _candidate

# Python 3's types module no longer exports the classic *Type aliases for
# builtin types.  Derive them from literals so contracts such as
# 'type: list: ListType' still resolve.  setdefault leaves every entry
# found above untouched, so nothing changes where the aliases exist.
for _name, _sample in (('NoneType', None), ('IntType', 0), ('LongType', 0),
                       ('FloatType', 0.0), ('ComplexType', 0j),
                       ('StringType', ''), ('TupleType', ()),
                       ('ListType', []), ('DictType', {}),
                       ('DictionaryType', {})):
    asType.setdefault(_name, type(_sample))
del _name, _candidate, _sample, _TYPE_OF_TYPES
>
>
def checkType(param, typeAsString):
    """ Return true when *param* satisfies the type named by *typeAsString*.

    typeAsString is a key of the asType table (e.g. 'IntType', 'ListType');
    the special name 'AnyType' accepts every value.  Instances of subclasses
    of the named type are accepted too, as with isinstance().
    An unknown type name raises KeyError from the asType lookup.
    """
    if typeAsString == 'AnyType':       # accept any
        return TRUE
    return isinstance(param, asType[typeAsString])


# Violation handling
# --------------------------------------------------------------------------
# Example of catching the exceptions that contract checking can raise:
#
#     try:
#         TestStack()
#     except DbCFormatError as msg:
#         PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg),
#                       "FORMAT ERROR:")
#     except DbCViolationError as msg:
#         PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg))


class DbCViolationError(Exception):
    """ Signals that a contract clause did not hold.

    Raised by violation handlers that choose to raise rather than print;
    the message is the offending contract line.
    """


class DbCFormatError(Exception):
    """ Signals a malformed contract clause (e.g. a bad 'name:Type' list). """
>
>
def FormatDbCTrace(trace, line, msg= "VIOLATION:"):
    """ Format a contract-violation report as a list of strings.

    trace -- stack entries as returned by traceback.extract_stack() or
             traceback.extract_tb(), oldest call first.
    line  -- the violated contract line (or other message text).
    msg   -- label put in front of *line* in the banner.

    Returns three banner strings (without trailing newlines), followed by the
    traceback.format_list() text of every entry before the first frame that
    belongs to the DbC machinery: a frame in DesignByContract.py at or before
    DBC_LAST_LINE.  If the trace contains no such frame, all of it is kept.
    """
    FIELD_FILENAME = 0
    FIELD_LINENO = 1
    banner = '=' * 60
    retval = [banner, msg + line, banner]

    # Walk past caller frames until the DbC code itself is reached, without
    # running off the end of the trace.
    depth = len(trace)
    trace_slice = 0
    while trace_slice < depth and (
            not trace[trace_slice][FIELD_FILENAME].endswith('DesignByContract.py')
            or trace[trace_slice][FIELD_LINENO] > DBC_LAST_LINE):
        trace_slice += 1
    retval.extend(traceback.format_list(trace[:trace_slice]))
    return retval
>
def PrintDbCTrace(trace, msg, str = "VIOLATION:"):
    """ Print the report built by FormatDbCTrace to standard output.

    The third parameter shadows the builtin str; its name is kept so that
    callers passing it by keyword keep working.
    """
    report = FormatDbCTrace(trace, msg, str)
    for entry in report:
        print(entry)
>
def LogDbCTrace(trace, line, fname='violation.txt'):
    """ Write the FormatDbCTrace report for *trace*/*line* to file *fname*.

    The file is truncated first, so it holds only the latest report.
    The banner lines from FormatDbCTrace have no trailing newline, while
    traceback entries already end in one.  Each line is therefore terminated
    here so the banners do not run together.  The report is built before the
    file is opened, and the file is closed even if writing fails.
    """
    lines = []
    for text in FormatDbCTrace(trace, line):
        if not text.endswith('\n'):
            text = text + '\n'
        lines.append(text)
    f = open(fname, 'w')
    try:
        f.writelines(lines)
    finally:
        f.close()
>
def contractViolationBehaviour(object, line):
    """ Default handler called when a contract clause fails.

    object -- the wrapped instance whose contract failed (not used here).
    line   -- the offending contract line.

    Prints the current stack, cut off at the DbC machinery, and lets
    execution continue.  To change the policy, substitute one of:
        raise DbCViolationError(line)                   # raise an exception
        LogDbCTrace(traceback.extract_stack(), line)    # write to a log file
    """
    PrintDbCTrace(traceback.extract_stack(), line)
>
>
>
>
# --------------------------------------------------------------------------
# Contract-line parsing
# --------------------------------------------------------------------------
def getParameters(orignal_line):
    """ Parse a 'name:Type, name:Type' contract clause into [name, type] pairs.

    Anything from the first '#' onwards is a comment and is ignored.
    Whitespace around each name and type is stripped.  Returns a real list
    of two-element lists, so callers may iterate it more than once.  Example:

        getParameters(' list: ListType, another:IntType # Should work')
        -> [['list', 'ListType'], ['another', 'IntType']]

    Raises DbCFormatError(orignal_line) if any comma-separated item is not a
    non-empty name and a non-empty type separated by exactly one ':'.
    """
    line = orignal_line
    comment_start = line.find('#')
    if comment_start != -1:
        line = line[:comment_start]         # remove trailing comment
    pairs = []
    for item in line.split(','):
        fields = item.split(':')
        if len(fields) != 2:
            raise DbCFormatError(orignal_line)
        name = fields[0].strip()
        typeName = fields[1].strip()
        if not name or not typeName:
            raise DbCFormatError(orignal_line)
        pairs.append([name, typeName])
    return pairs
>
>
>
# Attribute names that DbCClassWrapper keeps in its own __dict__; every other
# attribute read or written through the wrapper goes to the wrapped object.
DBC_CLASS_WRAPPER_PRIVATE_VARS = [
    '_wrapped_object',
    '_violationProc',
    '_docstr',
]


class DbCClassWrapper:
    """ Wrapper for classes to check Contracts.

    Attribute reads and writes are forwarded to the wrapped instance.  Bound
    methods read through the wrapper come back wrapped in DbCMethodWrapper.
    That wrapper checks the 'param:', 'require:', 'invar:' and 'ensure:'
    lines of the method's docstring around each call.  The 'invar:' and
    'type:' lines of the wrapped object's class docstring are checked:
      - when the wrapper is created,
      - after every attribute assignment made through the wrapper,
      - after every wrapped method call.
    Violations are reported by calling violationProc(object, line).
    """

    def __init__(self, object, violationProc = contractViolationBehaviour):
        """ param: object:InstanceType

        Note the leading underscores are required on variable names as we
        are sharing name space with the wrapped object.
        The class invariants are checked straight away, because the wrapped
        object's own __init__ ran before it was wrapped.
        """
        self._wrapped_object = object               # reference to wrapped instance
        self._violationProc = violationProc         # who you going to call?
        self._wrapped_object._dbc_wrapper = self    # lets method wrappers find us
        if not self._wrapped_object.__doc__:
            self._docstr = None
        else:
            # One stripped entry per docstring line.  Built as a real list
            # because it is scanned on every check.
            self._docstr = [text.strip() for text in
                            self._wrapped_object.__doc__.splitlines()]
        self.checkClassInvariants()

    def __getattr__(self, name):
        """ param: name:StringType

        Only called when normal lookup fails.  For a private var that means
        it has not been set yet, which is reported as AttributeError (the
        protocol hasattr/getattr/copy expect).  Other names are fetched from
        the wrapped object; bound methods are wrapped for contract checking.
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:
            try:
                return self.__dict__[name]
            except KeyError:
                raise AttributeError(name)
        attribute = getattr(self._wrapped_object, name)
        if type(attribute) == types.MethodType:
            attribute = DbCClassWrapper.DbCMethodWrapper(
                self._wrapped_object, attribute)
        return attribute

    def __setattr__(self, name, val):
        """ param: name:StringType, val:AnyType

        Private vars are stored on the wrapper itself.  Anything else is set
        on the wrapped object, after which the class invariants are checked.
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:
            self.__dict__[name] = val
        else:
            setattr(self._wrapped_object, name, val)
            self.checkClassInvariants()

    def checkClassInvariants(self):
        """ Check the invariants for the wrapped class.

        'invar:' expressions are evaluated with this module's globals,
        overlaid by 'self' (this wrapper) and then by the wrapped object's
        attributes.  So 'invar: list != None' tests the instance's 'list'
        attribute, not the builtin of that name.  'type:' lines check each
        named attribute with checkType.  Does nothing when the wrapped class
        has no docstring.
        """
        if not self._docstr:
            return
        scope = {}
        scope.update(globals())
        scope['self'] = self
        scope.update(self._wrapped_object.__dict__)
        for line in self._docstr:
            if line.startswith(KEYWORD_INVARIANT):
                if not eval(line[LEN_KEYWORD_INVARIANT:], scope):
                    self._violationProc(self._wrapped_object, line)

            elif line.startswith(KEYWORD_TYPE):
                for p in getParameters(line[LEN_KEYWORD_TYPE:]):
                    arg = getattr(self._wrapped_object, p[0])
                    if not checkType(arg, p[1]):
                        self._violationProc(self._wrapped_object, line)

    class DbCMethodWrapper:
        """ A Wrapper for methods.

        Contract expressions of a method are evaluated in self.private_vars:
          - the wrapped object's attributes,
          - 'param', holding the declared arguments,
          - 'old', holding deep copies of the attributes taken before the call.
        """
        def __init__(self, object, method):
            """ param: object:InstanceType, method:MethodType """
            self.wrapped_object = object
            self.method = method
            if method.__doc__:
                # Stripped lines as a real list: scanned before and after the call.
                self.docstr = [text.strip() for text in method.__doc__.splitlines()]
            else:
                self.docstr = None
            self.violationProc = self.wrapped_object._dbc_wrapper._violationProc
            self.private_vars = {}

        def __call__(self, *args, **kwargs):
            """ param: args:AnyType

            Only bother to check pre and post conditions if the method has
            a doc string.  The class invariants are checked after every call.
            Keyword arguments are passed through to the method.
            """
            if not self.docstr:
                retval = self.method(*args, **kwargs)
            else:
                self.private_vars = self.getWrappedClassPrivateVars()   # shallow copy
                self.checkBefore(args, kwargs)                          # pre-conditions
                self.copyPrivateVarsToOld()                             # snapshot into old
                retval = self.method(*args, **kwargs)
                self.private_vars.update(self.getWrappedClassPrivateVars())
                self.checkAfter()                                       # post-conditions
                self.private_vars = {}
            self.wrapped_object._dbc_wrapper.checkClassInvariants()     # always check class invars
            return retval

        def getWrappedClassPrivateVars(self):
            """ Return a shallow copy of the wrapped object's attribute dict,
            without the wrapper back-reference and other bookkeeping keys. """
            retval = {}
            attributes = self.wrapped_object.__dict__
            for key in attributes.keys():
                if not key in ['docstr', '_dbc_wrapper', '__builtins__']:
                    retval[key] = attributes[key]
            return retval

        class DbCDictWrapper:
            """ class that allows us to use the old. and param. notation in
            the DbC expressions."""
            pass

        def copyPrivateVarsToOld(self):
            """ Deep-copy the current private vars (except 'param') into an
            'old' namespace that 'ensure:' clauses can compare against. """
            if self.private_vars:
                old = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
                # eval() inserted __builtins__ while checking pre-conditions.
                try:
                    del self.private_vars['__builtins__']
                except KeyError:
                    pass
                for key in self.private_vars.keys():
                    if not key in ['param']:
                        setattr(old, key, copy.deepcopy(self.private_vars[key]))
                self.private_vars['old'] = old

        def displayPrivateVars(self):
            """ for debug """
            print('-' * 60)
            for i in self.private_vars.keys():
                if not i in ['__builtins__']:
                    print('[ %s ] = %s' % (i, self.private_vars[i]))
                    if i in ['param', 'old']:
                        print('  %s' % (self.private_vars[i].__dict__,))
            print('-' * 60)

        def checkBefore(self, args, kwargs=None):
            """ param: args:AnyType
            Pre-conditions
            Check param, require, invars
            require: self.docstr != None

            'param:' declarations are matched to the call's arguments in
            order, continuing across several 'param:' lines.  A declared
            parameter beyond the positional arguments is looked up in
            kwargs by name.  If it was not supplied at all, it is neither
            type-checked nor bound in 'param'.
            """
            if kwargs is None:
                kwargs = {}
            param = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
            self.private_vars['param'] = param
            position = 0        # index in args of the next declared parameter
            for line in self.docstr:
                if line.startswith(KEYWORD_PARAMETER):
                    for name, typeName in getParameters(line[LEN_KEYWORD_PARAMETER:]):
                        index = position
                        position = position + 1
                        if index < len(args):
                            value = args[index]
                        else:
                            try:
                                value = kwargs[name]
                            except KeyError:
                                continue        # not supplied; method default applies
                        if not checkType(value, typeName):
                            self.violationProc(self.wrapped_object, line)
                        setattr(param, name, value)

                elif line.startswith(KEYWORD_REQUIRE):
                    if not eval(line[LEN_KEYWORD_REQUIRE:], self.private_vars):
                        self.violationProc(self.wrapped_object, line)

                elif line.startswith(KEYWORD_INVARIANT):
                    if not eval(line[LEN_KEYWORD_INVARIANT:], self.private_vars):
                        self.violationProc(self.wrapped_object, line)

        def checkAfter(self):
            """ Post-conditions
            This checks the ensure and method invars
            require: self.docstr != None
            """
            for line in self.docstr:
                if line.startswith(KEYWORD_ENSURE):
                    if not eval(line[LEN_KEYWORD_ENSURE:], self.private_vars):
                        self.displayPrivateVars()       # debug dump of the namespace
                        self.violationProc(self.wrapped_object, ':"' + line + '"')

                elif line.startswith(KEYWORD_INVARIANT):
                    if not eval(line[LEN_KEYWORD_INVARIANT:], self.private_vars):
                        self.violationProc(self.wrapped_object, ':"' + line + '"')
>
>
>
# Frames in DesignByContract.py at or before this line belong to the DbC
# machinery above and are left out of violation traces by FormatDbCTrace;
# frames from the self-test code further down are still shown.  The value
# is taken from the current frame, so it stays correct when lines above
# are added or removed.
DBC_LAST_LINE = sys._getframe().f_lineno
>
>
if __name__ == '__main__':
    class Stack:
        """ A simple test class
        invar: list != None                             # Should work
        type: list: ListType                            # Should work
        """
        def __init__(self):
            self.list = []
            self.another = 2    # inc()'s 'require: another > 1' needs an initial value

        def ppush(self, val):
            """ param: val:IntType                      # Should work
            require: param.val >1                       # Should work
            invar: list != None                         # Should work
            ensure: len(list) == len(old.list)+1        # Should work
            """
            self.list.append(val)

        def ppop(self):
            """ invar: list != None                     # Should work
            ensure: len(list) == len(old.list)-1        # Should work
            """
            return self.list.pop()

        def inc(self, val):
            """ param: val:IntType
            require: another > 1                        # Should pass
            ensure: old.another > 100                   # Should fail
            ensure: another == old.another + 1          # Should pass
            ensure: another == old.another + 10         # Should fail
            """
            self.another = self.another + 1

    import unittest

    class DesignByContractTests(unittest.TestCase):
        def setUp(self):
            def onContractViolation(object, line):
                # Raise so that assertRaises can observe the violation.
                raise DbCViolationError(line)
            aStack = Stack()                                        # plain instance...
            self.ws = DbCClassWrapper(aStack, onContractViolation)  # ...and its wrapper

        def checkPushValidValue(self):
            self.ws.ppush(10)

        def checkPushInValidType(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, 'asd')

        def checkPushInValidValue(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, -1)

        def checkPopValid(self):
            self.ws.ppush(10)
            self.assertEqual(self.ws.ppop(), 10)

    # The test methods use the 'check' prefix instead of unittest's 'test'.
    loader = unittest.TestLoader()
    loader.testMethodPrefix = 'check'
    suite = loader.loadTestsFromTestCase(DesignByContractTests)
    retval = unittest.TextTestRunner().run(suite)
    sys.exit(not retval.wasSuccessful())    # non-zero exit status on failure
More information about the Python-list
mailing list