Design By Contract in python

Steve Pike spike at oztourer.nl
Wed Jul 25 10:11:06 EDT 2001


Interesting stuff, but I would enjoy it more if the source weren't
mangled by line wraps. I think I have reconstructed it correctly, but
perhaps you could make the original source available for download?

Thanx,
Steve Pike


On 24 Jul 2001 03:23:10 -0700, arthur.gordon at thales-esecurity.com
(Arthur Gordon) wrote:

>An implementation of Design by Contract in Python. It uses the
>documentation strings to hold the pre- and postconditions, as described
>in the paper http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html
>
>Enjoy
>
>-Arthur
>
>--------------- cut here -----------------------------
"""FILENAME:
        DesignByContract.py
   DESCRIPTION:
        An implementation of Design by Contract in Python (2.3 or later),
        based on a paper by R. Plosch @ Johannes Kepler University Linz,
        which can be found at:
        http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html

        Contracts live in the docstrings of classes and methods, one clause
        per line, each introduced by one of the keywords 'invar:', 'type:',
        'param:', 'require:' or 'ensure:'.
   AUTHOR:
       Arthur Gordon <arthur.gordon at thales-esecurity.com>
   HISTORY:
       18-Jul-01: Added support for old and param namespace.
       17-Jul-01: Added PrintDbCTrace - which prints out the stack
                  up to the DbC classes, no further.
       12-Jul-01: First attempt.
"""
import types
import sys
import traceback
import copy

# sys.hexversion encodes X.Y as 0xXXYY....; the original test used 0x200000,
# i.e. version 0.32, which no interpreter ever fails.  A library should not
# sys.exit() (let alone with status 0) when imported, so raise instead.
if sys.hexversion < 0x02030000:
    raise ImportError('DesignByContract.py requires Python version 2.3 or greater.')

# Keywords that introduce a contract clause.  Docstring lines are stripped
# before matching, so a keyword must be the first thing on its line.
KEYWORD_INVARIANT   = 'invar:'
KEYWORD_TYPE        = 'type:'
KEYWORD_PARAMETER   = 'param:'
KEYWORD_REQUIRE     = 'require:'
KEYWORD_ENSURE      = 'ensure:'

# Lengths of the keywords, used to slice the clause text off a line.
LEN_KEYWORD_INVARIANT   = len(KEYWORD_INVARIANT)
LEN_KEYWORD_TYPE        = len(KEYWORD_TYPE)
LEN_KEYWORD_PARAMETER   = len(KEYWORD_PARAMETER)
LEN_KEYWORD_REQUIRE     = len(KEYWORD_REQUIRE)
LEN_KEYWORD_ENSURE      = len(KEYWORD_ENSURE)

TRUE = 1==1

# Type names used in contracts that Python 3's `types` module no longer
# exports.  They are added with setdefault(), so on interpreters whose
# `types` module does provide them the original objects are kept.
_LEGACY_TYPE_ALIASES = (
    ('IntType', int),
    ('FloatType', float),
    ('ComplexType', complex),
    ('BooleanType', bool),
    ('StringType', str),
    ('ListType', list),
    ('TupleType', tuple),
    ('DictType', dict),
    ('DictionaryType', dict),
    ('NoneType', type(None)),
    ('TypeType', type),
)


def _buildTypeTable():
    """ Build the string -> type table used by checkType.

        Contains every type object exported by the `types` module, keyed by
        its attribute name (e.g. 'FunctionType'), plus _LEGACY_TYPE_ALIASES
        for names the module does not define.
    """
    table = {}
    for name in dir(types):
        candidate = getattr(types, name)
        if isinstance(candidate, type):
            table[name] = candidate
    for name, builtin in _LEGACY_TYPE_ALIASES:
        table.setdefault(name, builtin)
    return table


# Used for String to Type conversion, e.g. asType['NoneType'] is type(None).
asType = _buildTypeTable()
>
>
def checkType(param, typeAsString):
    """ Check that `param` has exactly the type named by `typeAsString`.

        typeAsString is a key of the asType table (e.g. 'IntType',
        'ListType'); the special name 'AnyType' accepts every value.  The
        comparison is on the exact type, so subclass instances do not match.

        Returns TRUE when the check passes and a false value otherwise.
        Raises DbCFormatError(typeAsString) when the name is not a known
        type: that is a mistake in the contract text, not a violation.
    """
    if typeAsString == 'AnyType':           # accept any
        return TRUE
    try:
        expected = asType[typeAsString]
    except KeyError:
        raise DbCFormatError(typeAsString)
    return type(param) == expected          # Check type is correct
# Violation handling
# --------------------------------------------------------------------------
# Contract failures are reported through the violationProc given to
# DbCClassWrapper.  A violationProc that raises DbCViolationError lets the
# caller handle violations, for example:
#
#     try:
#         TestStack()
#     except DbCFormatError as msg:
#         PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg),
#                       "FORMAT ERROR:")
#     except DbCViolationError as msg:
#         PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg))


class DbCViolationError(Exception):
    """ Error: Contract Violation.

        Carries the contract line that evaluated false.
    """
    pass


class DbCFormatError(Exception):
    """ Error: Incorrect Contract Format.

        Carries the contract text that could not be parsed.
    """
    pass
>
>
def FormatDbCTrace(trace, line, msg= "VIOLATION:"):
    """ Format `trace` for display without going back into the DbC classes.

        trace -- list of (filename, lineno, name, text) entries, outermost
                 first, as produced by traceback.extract_stack() or
                 traceback.extract_tb().
        line  -- the contract line (or message) being reported.
        msg   -- heading placed in front of `line`.

        Returns a list of strings: a banner containing msg+line, followed
        by the traceback.format_list() text of every frame before the first
        frame located in DesignByContract.py at or before DBC_LAST_LINE.
        When the trace contains no such frame, all of it is formatted.
    """
    FIELD_FILENAME  = 0
    FIELD_LINENO    = 1
    retval = ['=' * 60, msg + line, '=' * 60]
    # Find where to trace back to.  The scan is bounded by len(trace): the
    # trace may not pass through this module at all.
    trace_slice = 0
    while trace_slice < len(trace):
        entry = trace[trace_slice]
        if (entry[FIELD_FILENAME].endswith('DesignByContract.py')
                and entry[FIELD_LINENO] <= DBC_LAST_LINE):
            break
        trace_slice += 1
    retval.extend(traceback.format_list(trace[:trace_slice]))
    return retval
>
def PrintDbCTrace(trace, msg, str = "VIOLATION:"):
    """ Print the report built by FormatDbCTrace(trace, msg, str) to stdout.

        `str` is the heading placed before msg.  The parameter keeps its
        original name so existing keyword callers still work.
    """
    for line in FormatDbCTrace(trace, msg, str):
        # traceback.format_list() entries already end with a newline; drop
        # it so print does not emit an empty line after every frame.
        print(line.rstrip('\n'))
>
def LogDbCTrace(trace, line, fname='violation.txt'):
    """ Append the report built by FormatDbCTrace(trace, line) to `fname`.

        Every report entry is written terminated by exactly one newline
        (the banner lines have none of their own).  The file is opened in
        append mode so earlier violations are kept, and it is closed even
        if writing fails.
    """
    report = [entry.rstrip('\n') + '\n' for entry in FormatDbCTrace(trace, line)]
    f = open(fname, 'a')
    try:
        f.writelines(report)
    finally:
        f.close()
>
def contractViolationBehaviour(object, line):
    """ Default violationProc: report a broken contract on stdout.

        object -- the wrapped instance whose contract failed (not used).
        line   -- the contract text that evaluated false.

        Prints `line` together with the current call stack via
        PrintDbCTrace.  Other possible behaviours are raising
        DbCViolationError(line), or writing the report to a file with
        LogDbCTrace(traceback.extract_stack(), line).
    """
    current_stack = traceback.extract_stack()
    PrintDbCTrace(current_stack, line)
>    
>
>
>
# --------------------------------------------------------------------------
# Contract parsing


def getParameters(orignal_line):
    """ Parse a 'name: Type, name: Type   # comment' contract clause.

        Anything from the first '#' onward is ignored.  The rest is split
        on ',' into pairs and each pair on ':'.  Whitespace around both
        halves is stripped.  A list is built explicitly rather than relying
        on map() returning one.

        Returns a list of [name, typeName] lists, e.g.
            getParameters(' list: ListType, another:IntType   # Should work')
            -> [['list', 'ListType'], ['another', 'IntType']]

        Raises DbCFormatError(orignal_line) when a pair does not contain
        exactly one ':'.
    """
    line = orignal_line
    comment_start = line.find('#')
    if comment_start != -1:                 # remove trailing comments
        line = line[:comment_start]
    parameters = []
    for chunk in line.split(','):           # split by ,
        pair = chunk.split(':')             # split by :
        if len(pair) != 2:                  # should be two
            raise DbCFormatError(orignal_line)
        parameters.append([pair[0].strip(), pair[1].strip()])
    return parameters
>
>
>
# Attribute names stored on the wrapper itself instead of being forwarded to
# the wrapped object (see DbCClassWrapper.__getattr__ / __setattr__).
DBC_CLASS_WRAPPER_PRIVATE_VARS = ['_wrapped_object', '_violationProc', '_docstr']


class DbCClassWrapper:
    """ Wrapper for class instances that checks their Contracts.

        The wrapped instance's class docstring may contain
            invar: <expression>          class invariant
            type:  name: TypeName, ...   attribute type checks
        These are checked when the wrapper is created, after every attribute
        assignment made through the wrapper, and after every method call
        made through it.  Methods fetched through the wrapper are returned
        as DbCMethodWrapper objects, which also check the method's own
        param:/require:/invar:/ensure: clauses.
    """

    def __init__(self, object, violationProc=contractViolationBehaviour):
        """ param: object:InstanceType

            object        -- the instance to wrap.
            violationProc -- called as violationProc(object, line) for every
                             contract line that evaluates false.

            Note the leading underscores are required on variable names as
            we are sharing name space with the wrapped object.  The class
            invariants are checked here because the wrapped object's own
            __init__ ran unchecked.
        """
        self._wrapped_object = object                   # Reference to wrapped instance
        self._violationProc = violationProc             # Who you going to call?
        self._wrapped_object._dbc_wrapper = self        # Attach ourselves to the wrapped object
        if not self._wrapped_object.__doc__:
            self._docstr = None
        else:
            # One stripped entry per docstring line.  A real list (not a
            # lazy map) because it is iterated on every check.
            self._docstr = [entry.strip() for entry in
                            self._wrapped_object.__doc__.splitlines()]
            self.checkClassInvariants()                 # Since we did not check at __init__

    def __getattr__(self, name):
        """ param: name:StringType

            Only called for names not found on the wrapper itself.  Bound
            methods of the wrapped object come back as DbCMethodWrapper
            objects so that their contracts are checked when called.
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:
            # Normal lookup has already searched self.__dict__, so a private
            # name arriving here has not been assigned yet.  Report it as
            # AttributeError (what hasattr()/copy expect), not KeyError.
            raise AttributeError(name)
        attribute = getattr(self._wrapped_object, name)
        if type(attribute) == types.MethodType:         # If it is a method...
            attribute = DbCClassWrapper.DbCMethodWrapper(
                self._wrapped_object, attribute)        # ...wrap it
        return attribute

    def __setattr__(self, name, val):
        """ param: name:StringType, val:AnyType

            Wrapper-private names are stored on the wrapper.  Everything
            else is assigned on the wrapped object, followed by a
            class-invariant check.
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:
            self.__dict__[name] = val
        else:
            setattr(self._wrapped_object, name, val)
            self.checkClassInvariants()

    def _invariantNamespace(self):
        """ Local namespace for evaluating class invar: expressions.

            Holds the wrapped object's attributes (except the _dbc_wrapper
            back-reference) and 'self' bound to this wrapper.  'self' is
            what a bare eval() inside checkClassInvariants used to see.
        """
        namespace = {'self': self}
        for key, value in self._wrapped_object.__dict__.items():
            if key != '_dbc_wrapper':
                namespace[key] = value
        return namespace

    def checkClassInvariants(self):
        """ Check the invariants for the wrapped class.

            invar: expressions are evaluated with this module's globals and
            the wrapped object's attributes in scope, so 'invar: list != None'
            tests the instance's `list` attribute rather than the builtin.
            type: clauses check the named attributes with checkType.  Each
            failing line is passed to the violationProc.  Protects itself
            against self._docstr == None.
        """
        if not self._docstr:
            return
        namespace = self._invariantNamespace()
        for line in self._docstr:
            if line.startswith(KEYWORD_INVARIANT):
                expression = line[LEN_KEYWORD_INVARIANT:].strip()
                if not eval(expression, globals(), namespace):
                    self._violationProc(self._wrapped_object, line)
            elif line.startswith(KEYWORD_TYPE):
                for name, typeName in getParameters(line[LEN_KEYWORD_TYPE:]):
                    value = getattr(self._wrapped_object, name)
                    if not checkType(value, typeName):
                        self._violationProc(self._wrapped_object, line)

    class DbCMethodWrapper:
        """ A Wrapper for bound methods that checks the method's contracts.

            The method docstring may contain
                param:   name: TypeName, ...   argument type checks
                require: <expression>          pre-condition
                invar:   <expression>          checked before and after
                ensure:  <expression>          post-condition
            Expressions are evaluated with the instance's attributes as
            globals, plus `param` (the declared arguments of this call) and,
            after the call, `old` (deep copies of the attributes taken just
            before the call).  The class invariants are checked after every
            call.
        """

        def __init__(self, object, method):
            """ param: object:InstanceType, method:MethodType """
            self.wrapped_object = object
            self.method = method
            if method.__doc__:                          # has documentation string
                self.docstr = [entry.strip() for entry in method.__doc__.splitlines()]
            else:
                self.docstr = None
            self.violationProc = self.wrapped_object._dbc_wrapper._violationProc
            self.private_vars = {}

        def __call__(self, *args, **kwargs):
            """ param: args:AnyType

                Pre- and post-conditions are only checked when the method
                has a doc string; the class invariants are still checked at
                the end of every call.  Keyword arguments are passed through
                to the method and are matched to param: entries by name.
            """
            if not self.docstr:
                retval = self.method(*args, **kwargs)               # Call the method
            else:
                self.private_vars = self.getWrappedClassPrivateVars()   # shallow-copy
                self.checkBefore(args, **kwargs)                    # Pre-conditions
                self.copyPrivateVarsToOld()                         # Copy members to old
                retval = self.method(*args, **kwargs)               # Call the method
                self.private_vars.update(self.getWrappedClassPrivateVars())
                self.checkAfter()                                   # Post-conditions
                self.private_vars = {}
            self.wrapped_object._dbc_wrapper.checkClassInvariants()  # always check class invars!
            return retval

        def getWrappedClassPrivateVars(self):
            """ Shallow copy of the wrapped object's __dict__, without the
                wrapper back-reference and bookkeeping keys. """
            retval = {}
            for key, value in self.wrapped_object.__dict__.items():
                if key not in ('docstr', '_dbc_wrapper', '__builtins__'):  # don't want to copy these
                    retval[key] = value
            return retval

        class DbCDictWrapper:
            """ Attribute bag that allows the old. and param. notation in
                DbC expressions. """
            pass

        def copyPrivateVarsToOld(self):
            """ Snapshot the method's private vars into `old`.

                Every entry except `param` (and the __builtins__ that eval()
                adds to a globals dict) is deep-copied, so that ensure:
                clauses can compare against the pre-call state even when an
                attribute is mutated in place.
            """
            if self.private_vars:                               # may not have any
                old = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
                if '__builtins__' in self.private_vars:
                    del self.private_vars['__builtins__']
                for key, value in self.private_vars.items():
                    if key != 'param':                          # don't copy param over
                        setattr(old, key, copy.deepcopy(value))     # deep-copy
                self.private_vars['old'] = old

        def displayPrivateVars(self):
            """ for debug: dump private_vars (and the param/old bags) """
            print('-' * 60)
            for key, value in self.private_vars.items():
                if key != '__builtins__':
                    print('[ %s ] = %s' % (key, value))
                    if key in ('param', 'old'):
                        print('   %s' % (value.__dict__,))
            print('-' * 60)

        def checkBefore(self, *args, **kwargs):
            """ param: args:AnyType
                Pre-conditions: check param:, require: and invar: clauses.

                Called as checkBefore(argsTuple, **kwargs).  Each declared
                parameter is taken from the positional arguments in order,
                or failing that from the keyword arguments by name.  A
                declared parameter that was not passed (its default applies)
                is neither type-checked nor bound in `param`.
                require: self.docstr != None
            """
            args = args[0]                                      # unpack args from tuple
            param = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
            self.private_vars['param'] = param
            for line in self.docstr:                            # For each line...
                if line.startswith(KEYWORD_PARAMETER):
                    parameters = getParameters(line[LEN_KEYWORD_PARAMETER:])
                    for index, (name, typeName) in enumerate(parameters):
                        if index < len(args):
                            value = args[index]
                        elif name in kwargs:
                            value = kwargs[name]
                        else:
                            continue                            # not supplied: default applies
                        if not checkType(value, typeName):      # Check type is correct
                            self.violationProc(self.wrapped_object, line)
                        setattr(param, name, value)             # assign to params
                elif line.startswith(KEYWORD_REQUIRE):
                    if not eval(line[LEN_KEYWORD_REQUIRE:].strip(), self.private_vars):
                        self.violationProc(self.wrapped_object, line)
                elif line.startswith(KEYWORD_INVARIANT):
                    if not eval(line[LEN_KEYWORD_INVARIANT:].strip(), self.private_vars):
                        self.violationProc(self.wrapped_object, line)

        def checkAfter(self):
            """ Post-conditions: check the ensure: and invar: clauses.

                Failing lines are reported quoted, as ':"<line>"'.
                require: self.docstr != None
            """
            for line in self.docstr:
                if line.startswith(KEYWORD_ENSURE):
                    expression = line[LEN_KEYWORD_ENSURE:]
                elif line.startswith(KEYWORD_INVARIANT):
                    expression = line[LEN_KEYWORD_INVARIANT:]
                else:
                    continue
                if not eval(expression.strip(), self.private_vars):
                    self.violationProc(self.wrapped_object, ':"' + line + '"')
# FormatDbCTrace stops a stack dump at the first frame that lies in
# DesignByContract.py at or before this line (the DbC library code above);
# frames after it, such as the self-test below, are treated as user code.
# The value is the line number of this assignment, taken from the running
# frame so that editing the file cannot leave a stale hard-coded number.
DBC_LAST_LINE = sys._getframe().f_lineno
>
>    
if __name__ == '__main__':
    class Stack:
        """ A simple test class
            invar: list != None                         # Should work
            type: list: ListType                        # Should work
        """
        def __init__(self):
            self.list = []
            # inc() reads `another`; its require: needs another > 1.
            self.another = 2

        def ppush(self, val):
            """ param:   val:IntType                        # Should work
                require: param.val >1                       # Should work
                invar:   list != None                       # Should work
                ensure:  len(list) == len(old.list)+1       # Should work
            """
            self.list.append(val)

        def ppop(self):
            """ invar:  list != None                        # Should work
                ensure: len(list) == len(old.list)-1        # Should work
            """
            return self.list.pop()

        def inc(self, val):
            """ param:  val:IntType
                require: another > 1                        # Should pass
                ensure: old.another > 100                   # Should fail
                ensure: another == old.another + 1          # Should pass
                ensure: another == old.another + 10         # Should fail
            """
            self.another = self.another + 1

    import unittest

    class DesignByContractTests(unittest.TestCase):
        def setUp(self):
            def onContractViolation(obj, line):
                raise DbCViolationError(line)               # Raise exception
            aStack = Stack()                                # Call the constructor for class
            self.ws = DbCClassWrapper(aStack, onContractViolation)   # And wrap it

        def checkPushValidValue(self):
            self.ws.ppush(10)

        def checkPushInValidType(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, 'asd')

        def checkPushInValidValue(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, -1)

        def checkPopValid(self):
            self.ws.ppush(10)
            self.assertEqual(self.ws.ppop(), 10)

        def checkIncFailingEnsure(self):
            # 'ensure: old.another > 100' is documented to fail.
            self.assertRaises(DbCViolationError, self.ws.inc, 1)

    # The test methods use the 'check' prefix.  unittest.makeSuite() was
    # removed in Python 3.13, so configure a loader with that prefix instead.
    loader = unittest.TestLoader()
    loader.testMethodPrefix = 'check'
    suite = loader.loadTestsFromTestCase(DesignByContractTests)
    retval = unittest.TextTestRunner().run(suite)
    sys.exit(not retval.wasSuccessful())




More information about the Python-list mailing list