[pypy-svn] r4293 - in pypy/branch/typeinference/pypy: annotation translator translator/test

arigo at codespeak.net arigo at codespeak.net
Thu May 6 16:39:01 CEST 2004


Author: arigo
Date: Thu May  6 16:39:00 2004
New Revision: 4293

Modified:
   pypy/branch/typeinference/pypy/annotation/binaryop.py
   pypy/branch/typeinference/pypy/annotation/factory.py
   pypy/branch/typeinference/pypy/annotation/model.py
   pypy/branch/typeinference/pypy/annotation/unaryop.py
   pypy/branch/typeinference/pypy/translator/annrpython.py
   pypy/branch/typeinference/pypy/translator/genpyrex.py
   pypy/branch/typeinference/pypy/translator/test/snippet.py
   pypy/branch/typeinference/pypy/translator/test/test_annrpython.py
Log:
The class/instance model.
Hopefully the last round of changes all over the files.
More tests pass, including two new tests with inheritance.



Modified: pypy/branch/typeinference/pypy/annotation/binaryop.py
==============================================================================
--- pypy/branch/typeinference/pypy/annotation/binaryop.py	(original)
+++ pypy/branch/typeinference/pypy/annotation/binaryop.py	Thu May  6 16:39:00 2004
@@ -6,8 +6,9 @@
 from pypy.annotation.model import SomeObject, SomeInteger, SomeBool
 from pypy.annotation.model import SomeString, SomeList
 from pypy.annotation.model import SomeTuple, SomeImpossibleValue
+from pypy.annotation.model import SomeInstance
 from pypy.annotation.model import set, setunion, missing_operation
-from pypy.annotation.factory import NeedGeneralization
+from pypy.annotation.factory import BlockedInference
 
 
 # XXX unify this with ObjSpace.MethodTable
@@ -107,9 +108,11 @@
     def getitem((lst1, int2)):
         return lst1.s_item
 
-    def setitem((lst1, int2), value):
-        if not lst1.s_item.contains(value):
-            raise NeedGeneralization(lst1, value)
+    def setitem((lst1, int2), s_value):
+        if not lst1.s_item.contains(s_value):
+            for factory in lst1.factories:
+                factory.generalize(s_value)
+            raise BlockedInference(lst1.factories)
 
 
 class __extend__(pairtype(SomeInteger, SomeList)):
@@ -118,6 +121,13 @@
         return lst2
 
 
+class __extend__(pairtype(SomeInstance, SomeInstance)):
+
+    def union((ins1, ins2)):
+        basedef = ins1.classdef.commonbase(ins2.classdef)
+        return SomeInstance(basedef)
+
+
 class __extend__(pairtype(SomeImpossibleValue, SomeObject)):
     def union((imp1, obj2)):
         return obj2

Modified: pypy/branch/typeinference/pypy/annotation/factory.py
==============================================================================
--- pypy/branch/typeinference/pypy/annotation/factory.py	(original)
+++ pypy/branch/typeinference/pypy/annotation/factory.py	Thu May  6 16:39:00 2004
@@ -6,24 +6,19 @@
 
 """
 
+from __future__ import generators
 from pypy.annotation.pairtype import pair
 from pypy.annotation.model import SomeImpossibleValue, SomeList
+from pypy.annotation.model import SomeObject, SomeInstance
 
 
 class BlockedInference(Exception):
     """This exception signals the type inference engine that the situation
     is currently blocked, and that it should try to progress elsewhere."""
-    invalidatefactories = ()  # factories that need to be invalidated
 
-class NeedGeneralization(BlockedInference):
-    """The mutable object s_mutable requires generalization.
-    The *args are passed to the generalize() method of the factory."""
-
-    def __init__(self, s_mutable, *args):
-        BlockedInference.__init__(self, s_mutable, *args)
-        for factory in s_mutable.factories:
-            factory.generalize(*args)
-        self.invalidatefactories = s_mutable.factories
+    def __init__(self, factories = ()):
+        # factories that need to be invalidated
+        self.invalidatefactories = factories
 
 
 #
@@ -38,3 +33,86 @@
 
     def generalize(self, s_new_item):
         self.s_item = pair(self.s_item, s_new_item).union()
+
+
+class InstanceFactory:
+
+    def __init__(self, cls, userclasses):
+        self.classdef = getclassdef(cls, userclasses)
+        self.classdef.instancefactories[self] = True
+
+    def create(self):
+        return SomeInstance(self.classdef)
+
+
+class ClassDef:
+    "Wraps a user class."
+
+    def __init__(self, cls, userclasses):
+        self.attrs = {}          # attrs is updated with new information
+        self.revision = 0        # which increases the revision number
+        self.instancefactories = {}
+        self.cls = cls
+        self.subdefs = {}
+        assert len(cls.__bases__) <= 1, "single inheritance only right now"
+        if cls.__bases__:
+            base = cls.__bases__[0]
+        else:
+            base = object
+        self.basedef = getclassdef(base, userclasses)
+        if self.basedef:
+            self.basedef.subdefs[cls] = self
+
+    def __repr__(self):
+        return '<ClassDef %s.%s>' % (self.cls.__module__, self.cls.__name__)
+
+    def commonbase(self, other):
+        while other is not None and not issubclass(self.cls, other.cls):
+            other = other.basedef
+        return other
+
+    def getmro(self):
+        while self is not None:
+            yield self
+            self = self.basedef
+
+    def getallsubdefs(self):
+        pending = [self]
+        seen = {}
+        for clsdef in pending:
+            yield clsdef
+            for sub in clsdef.subdefs.values():
+                if sub not in seen:
+                    pending.append(sub)
+                    seen[sub] = True
+
+    def getallfactories(self):
+        factories = {}
+        for clsdef in self.getallsubdefs():
+            factories.update(clsdef.instancefactories)
+        return factories
+
+    def generalize(self, attr, s_value):
+        # we make sure that an attribute never appears both in a class
+        # and in some subclass, in two steps:
+        # (1) assert that the attribute is in no superclass
+        for clsdef in self.getmro():
+            assert clsdef is self or attr not in clsdef.attrs
+        # (2) remove the attribute from subclasses
+        for subdef in self.getallsubdefs():
+            if attr in subdef.attrs:
+                s_value = pair(s_value, subdef.attrs[attr]).union()
+                del subdef.attrs[attr]
+            # bump the revision number of this class and all subclasses
+            subdef.revision += 1
+        self.attrs[attr] = s_value
+
+
+def getclassdef(cls, cache):
+    if cls is object:
+        return None
+    try:
+        return cache[cls]
+    except KeyError:
+        cache[cls] = ClassDef(cls, cache)
+        return cache[cls]

Modified: pypy/branch/typeinference/pypy/annotation/model.py
==============================================================================
--- pypy/branch/typeinference/pypy/annotation/model.py	(original)
+++ pypy/branch/typeinference/pypy/annotation/model.py	Thu May  6 16:39:00 2004
@@ -45,7 +45,7 @@
         kwds = ', '.join(['%s=%r' % item for item in self.__dict__.items()])
         return '%s(%s)' % (self.__class__.__name__, kwds)
     def contains(self, other):
-        return pair(self, other).union() == self
+        return self == other or pair(self, other).union() == self
     def is_constant(self):
         return hasattr(self, 'const')
 
@@ -81,6 +81,13 @@
     def len(self):
         return immutablevalue(len(self.items))
 
+class SomeInstance(SomeObject):
+    "Stands for an instance of a (user-defined) class."
+    def __init__(self, classdef):
+        self.classdef = classdef
+        self.knowntype = classdef.cls
+        self.revision = classdef.revision
+
 class SomeImpossibleValue(SomeObject):
     """The empty set.  Instances are placeholders for objects that
     will never show up at run-time, e.g. elements of an empty list."""

Modified: pypy/branch/typeinference/pypy/annotation/unaryop.py
==============================================================================
--- pypy/branch/typeinference/pypy/annotation/unaryop.py	(original)
+++ pypy/branch/typeinference/pypy/annotation/unaryop.py	Thu May  6 16:39:00 2004
@@ -5,11 +5,13 @@
 from pypy.annotation.pairtype import pair, pairtype
 from pypy.annotation.model import SomeObject, SomeInteger, SomeBool
 from pypy.annotation.model import SomeString, SomeList
-from pypy.annotation.model import SomeTuple
+from pypy.annotation.model import SomeTuple, SomeImpossibleValue
+from pypy.annotation.model import SomeInstance
 from pypy.annotation.model import set, setunion, missing_operation
+from pypy.annotation.factory import BlockedInference
 
 
-UNARY_OPERATIONS = set(['len', 'is_true'])
+UNARY_OPERATIONS = set(['len', 'is_true', 'getattr', 'setattr'])
 
 for opname in UNARY_OPERATIONS:
     missing_operation(SomeObject, opname)
@@ -22,3 +24,43 @@
 
     def is_true(obj):
         return SomeBool()
+
+
+class __extend__(SomeInstance):
+
+    def currentdef(ins):
+        if ins.revision != ins.classdef.revision:
+            print ins.revision, ins.classdef.revision
+            raise BlockedInference()
+        return ins.classdef
+
+    def getattr(ins, attr):
+        if attr.is_constant() and isinstance(attr.const, str):
+            attr = attr.const
+            # look for the attribute in the MRO order
+            for clsdef in ins.currentdef().getmro():
+                if attr in clsdef.attrs:
+                    return clsdef.attrs[attr]
+            # maybe the attribute exists in some subclass? if so, lift it
+            clsdef = ins.classdef
+            clsdef.generalize(attr, SomeImpossibleValue())
+            raise BlockedInference(clsdef.getallfactories())
+        return SomeObject()
+
+    def setattr(ins, attr, s_value):
+        if attr.is_constant() and isinstance(attr.const, str):
+            attr = attr.const
+            for clsdef in ins.currentdef().getmro():
+                if attr in clsdef.attrs:
+                    # look for the attribute in ins.classdef or a parent class
+                    s_existing = clsdef.attrs[attr]
+                    if s_existing.contains(s_value):
+                        return   # already general enough, nothing to do
+                    break
+            else:
+                # if the attribute doesn't exist yet, create it here
+                clsdef = ins.classdef
+            # create or update the attribute in clsdef
+            clsdef.generalize(attr, s_value)
+            raise BlockedInference(clsdef.getallfactories())
+        return SomeObject()

Modified: pypy/branch/typeinference/pypy/translator/annrpython.py
==============================================================================
--- pypy/branch/typeinference/pypy/translator/annrpython.py	(original)
+++ pypy/branch/typeinference/pypy/translator/annrpython.py	Thu May  6 16:39:00 2004
@@ -1,10 +1,10 @@
 from __future__ import generators
 
-from types import FunctionType
+from types import FunctionType, ClassType
 from pypy.annotation import model as annmodel
 from pypy.annotation.model import pair
-from pypy.annotation.factory import ListFactory
-from pypy.annotation.factory import BlockedInference, NeedGeneralization
+from pypy.annotation.factory import ListFactory, InstanceFactory
+from pypy.annotation.factory import BlockedInference
 from pypy.objspace.flow.model import Variable, Constant, UndefinedConstant
 from pypy.objspace.flow.model import SpaceOperation
 
@@ -23,6 +23,7 @@
         self.annotated = {}      # set of blocks already seen
         self.creationpoints = {} # map positions-in-blocks to Factories
         self.translator = translator
+        self.userclasses = {}    # set of user classes
 
     #___ convenience high-level interface __________________
 
@@ -55,6 +56,18 @@
             raise TypeError, ("Variable or Constant instance expected, "
                               "got %r" % (variable,))
 
+    def getuserclasses(self):
+        """Return a set of known user classes."""
+        return self.userclasses
+
+    def getuserattributes(self, cls):
+        """Enumerate the attributes of the given user class, as Variable()s."""
+        clsdef = self.userclasses[cls]
+        for attr, s_value in clsdef.attrs.items():
+            v = Variable(name=attr)
+            self.bindings[v] = s_value
+            yield v
+
 
     #___ medium-level interface ____________________________
 
@@ -73,7 +86,7 @@
             self.processblock(block, cells)
         if False in self.annotated.values():
             raise AnnotatorError('%d blocks are still blocked' %
-                                 len(self.annotated.values().count(False)))
+                                 self.annotated.values().count(False))
 
     def binding(self, arg):
         "Gives the SomeValue corresponding to the given Variable or Constant."
@@ -86,10 +99,6 @@
         else:
             raise TypeError, 'Variable or Constant expected, got %r' % (arg,)
 
-    def constant(self, value):
-        "Turn a value into a SomeValue with the proper annotations."
-        return annmodel.immutablevalue(arg.value)
-
 
     #___ simplification (should be moved elsewhere?) _______
 
@@ -149,12 +158,13 @@
             try:
                 self.flowin(block)
             except BlockedInference, e:
+                #print '_'*60
+                #print 'Blocked at %r:' % (self.curblockpos,)
+                #import traceback, sys
+                #traceback.print_tb(sys.exc_info()[2])
                 self.annotated[block] = False   # failed, hopefully temporarily
                 for factory in e.invalidatefactories:
                     self.reflowpendingblock(factory.block)
-            else:
-                return True   # progressed
-        return False
 
     def reflowpendingblock(self, block):
         self.pendingblocks.append((block, None))
@@ -178,21 +188,19 @@
 
     def flowin(self, block):
         #print 'Flowing', block, [self.binding(a) for a in block.inputargs]
-        if block.operations:
-            for i in range(len(block.operations)):
-                self.curblockpos = block, i
-                self.consider_op(block.operations[i])
-            del self.curblockpos
+        for i in range(len(block.operations)):
+            self.curblockpos = block, i
+            self.consider_op(block.operations[i])
         for link in block.exits:
             cells = [self.binding(a) for a in link.args]
             self.addpendingblock(link.target, cells)
 
-    def getfactory(self, factorycls):
+    def getfactory(self, factorycls, *factoryargs):
         try:
             factory = self.creationpoints[self.curblockpos]
         except KeyError:
             block = self.curblockpos[0]
-            factory = factorycls()
+            factory = factorycls(*factoryargs)
             factory.block = block
             self.creationpoints[self.curblockpos] = factory
         # self.curblockpos is an arbitrary key that identifies a specific
@@ -278,11 +286,13 @@
         elif isinstance(func, FunctionType) and self.translator:
             args = self.decode_simple_call(s_varargs, s_kwargs)
             return self.translator.consider_call(self, func, args)
+        elif (isinstance(func, (type, ClassType)) and
+              func.__module__ != '__builtin__'):
+            # XXX flow into __init__/__new__
+            factory = self.getfactory(InstanceFactory, func, self.userclasses)
+            return factory.create()
         elif isinstance(func,type):
             return annmodel.valueoftype(func)
-##            # XXX flow into __init__/__new__
-##            if func.__module__ != '__builtin__':
-##                self.userclasses.setdefault(func, {})
         return annmodel.SomeObject()
 
 

Modified: pypy/branch/typeinference/pypy/translator/genpyrex.py
==============================================================================
--- pypy/branch/typeinference/pypy/translator/genpyrex.py	(original)
+++ pypy/branch/typeinference/pypy/translator/genpyrex.py	Thu May  6 16:39:00 2004
@@ -225,8 +225,8 @@
         vartype = self.get_type(var)
         if vartype in (int, bool):
             prefix = "i_"
-        #elif self.annotator and vartype in self.annotator.getuserclasses():
-        #    prefix = "p_"
+        elif self.annotator and vartype in self.annotator.getuserclasses():
+            prefix = "p_"
         else:
             prefix = ""
         return prefix + var.name
@@ -235,8 +235,8 @@
         vartype = self.get_type(var)
         if vartype in (int, bool):
             ctype = "int"
-        #elif self.annotator and vartype in self.annotator.getuserclasses():
-        #    ctype = self.get_classname(vartype)
+        elif self.annotator and vartype in self.annotator.getuserclasses():
+            ctype = self.get_classname(vartype)
         else:
             ctype = "object"
 
@@ -332,20 +332,20 @@
         if self.annotator:
             self.lines = []
             self.indent = 0
-##            for cls in self.annotator.getuserclasses():
-##                self.putline("cdef class %s:" % self.get_classname(cls))
-##                self.indent += 1
-##                empty = True
-##                for var in self.annotator.getuserattributes(cls):
-##                    vartype, varname = self._paramvardecl(var)
-##                    varname = var.name   # no 'i_' prefix
-##                    self.putline("cdef %s %s" % (vartype, varname))
-##                    empty = False
-##                else:
-##                    if empty:
-##                        self.putline("pass")
-##                self.indent -= 1
-##                self.putline("")
+            for cls in self.annotator.getuserclasses():
+                self.putline("cdef class %s:" % self.get_classname(cls))
+                self.indent += 1
+                empty = True
+                for var in self.annotator.getuserattributes(cls):
+                    vartype, varname = self._paramvardecl(var)
+                    varname = var.name   # no 'i_' prefix
+                    self.putline("cdef %s %s" % (vartype, varname))
+                    empty = False
+                else:
+                    if empty:
+                        self.putline("pass")
+                self.indent -= 1
+                self.putline("")
             return '\n'.join(self.lines)
         else:
             return ''

Modified: pypy/branch/typeinference/pypy/translator/test/snippet.py
==============================================================================
--- pypy/branch/typeinference/pypy/translator/test/snippet.py	(original)
+++ pypy/branch/typeinference/pypy/translator/test/snippet.py	Thu May  6 16:39:00 2004
@@ -287,3 +287,22 @@
     c.a = 1
     c.a = 2
     return c.a
+
+class D(C): pass
+class E(C): pass
+
+def inheritance1():
+    d = D()
+    d.stuff = ()
+    e = E()
+    e.stuff = -12
+    e.stuff = 3
+    lst = [d, e]
+    return d.stuff, e.stuff
+
+def inheritance2():
+    d = D()
+    d.stuff = (-12, -12)
+    e = E()
+    e.stuff = (3, "world")
+    return C().stuff

Modified: pypy/branch/typeinference/pypy/translator/test/test_annrpython.py
==============================================================================
--- pypy/branch/typeinference/pypy/translator/test/test_annrpython.py	(original)
+++ pypy/branch/typeinference/pypy/translator/test/test_annrpython.py	Thu May  6 16:39:00 2004
@@ -150,6 +150,29 @@
         # result should be an integer
         self.assertEquals(a.gettype(graph.getreturnvar()), int)
 
+    def test_inheritance1(self):
+        translator = Translator(snippet.inheritance1)
+        graph = translator.getflowgraph()
+        a = RPythonAnnotator(translator)
+        a.build_types(graph, [])
+        # result should be exactly:
+        self.assertEquals(a.binding(graph.getreturnvar()),
+                          annmodel.SomeTuple([
+                              annmodel.SomeTuple([]),
+                              annmodel.SomeInteger()
+                              ]))
+
+    def test_inheritance2(self):
+        translator = Translator(snippet.inheritance2)
+        graph = translator.getflowgraph()
+        a = RPythonAnnotator(translator)
+        a.build_types(graph, [])
+        # result should be exactly:
+        self.assertEquals(a.binding(graph.getreturnvar()),
+                          annmodel.SomeTuple([
+                              annmodel.SomeInteger(),
+                              annmodel.SomeObject()
+                              ]))
 
 
 def g(n):


More information about the Pypy-commit mailing list