[pypy-svn] r17440 - in pypy/dist/pypy: annotation rpython rpython/test

arigo at codespeak.net arigo at codespeak.net
Sat Sep 10 14:40:05 CEST 2005


Author: arigo
Date: Sat Sep 10 14:40:02 2005
New Revision: 17440

Removed:
   pypy/dist/pypy/rpython/rconstantdict.py
   pypy/dist/pypy/rpython/remptydict.py
Modified:
   pypy/dist/pypy/annotation/builtin.py
   pypy/dist/pypy/rpython/rbuiltin.py
   pypy/dist/pypy/rpython/rclass.py
   pypy/dist/pypy/rpython/rdict.py
   pypy/dist/pypy/rpython/rint.py
   pypy/dist/pypy/rpython/rmodel.py
   pypy/dist/pypy/rpython/rstr.py
   pypy/dist/pypy/rpython/test/test_rdict.py
Log:
Sanitization of the multiple low-level dict implementations.
There is now only a generic DictRepr left.  This is slightly
less efficient e.g. for string keys, but we can think later
about regaining this performance.

Done by adding a get_ll_hash_function() to some Reprs, similar
to the already-existing get_ll_eq_function().  The dict lookup
uses these two functions generically.

The tests pass, but anything in rdict.py that is not covered by the
tests may still contain an oversight.



Modified: pypy/dist/pypy/annotation/builtin.py
==============================================================================
--- pypy/dist/pypy/annotation/builtin.py	(original)
+++ pypy/dist/pypy/annotation/builtin.py	Sat Sep 10 14:40:02 2005
@@ -327,6 +327,9 @@
     cast_p = lltype.cast_pointer(PtrT.const, s_p.ll_ptrtype._defl())
     return SomePtr(ll_ptrtype=lltype.typeOf(cast_p))
 
+def cast_ptr_to_int(s_ptr): # xxx
+    return SomeInteger()
+
 def getRuntimeTypeInfo(T):
     assert T.is_constant()
     return immutablevalue(lltype.getRuntimeTypeInfo(T.const))
@@ -339,6 +342,7 @@
 BUILTIN_ANALYZERS[lltype.typeOf] = typeOf
 BUILTIN_ANALYZERS[lltype.nullptr] = nullptr
 BUILTIN_ANALYZERS[lltype.cast_pointer] = cast_pointer
+BUILTIN_ANALYZERS[lltype.cast_ptr_to_int] = cast_ptr_to_int
 BUILTIN_ANALYZERS[lltype.getRuntimeTypeInfo] = getRuntimeTypeInfo
 BUILTIN_ANALYZERS[lltype.runtime_type_info] = runtime_type_info
 

Modified: pypy/dist/pypy/rpython/rbuiltin.py
==============================================================================
--- pypy/dist/pypy/rpython/rbuiltin.py	(original)
+++ pypy/dist/pypy/rpython/rbuiltin.py	Sat Sep 10 14:40:02 2005
@@ -254,6 +254,12 @@
     return hop.genop('cast_pointer', [v_input],    # v_type implicit in r_result
                      resulttype = hop.r_result.lowleveltype)
 
+def rtype_cast_ptr_to_int(hop):
+    assert isinstance(hop.args_r[0], rptr.PtrRepr)
+    vlist = hop.inputargs(hop.args_r[0])
+    return hop.genop('cast_ptr_to_int', vlist,
+                     resulttype = lltype.Signed)
+
 def rtype_runtime_type_info(hop):
     assert isinstance(hop.args_r[0], rptr.PtrRepr)
     vlist = hop.inputargs(hop.args_r[0])
@@ -263,6 +269,7 @@
 
 BUILTIN_TYPER[lltype.malloc] = rtype_malloc
 BUILTIN_TYPER[lltype.cast_pointer] = rtype_cast_pointer
+BUILTIN_TYPER[lltype.cast_ptr_to_int] = rtype_cast_ptr_to_int
 BUILTIN_TYPER[lltype.typeOf] = rtype_const_result
 BUILTIN_TYPER[lltype.nullptr] = rtype_const_result
 BUILTIN_TYPER[lltype.getRuntimeTypeInfo] = rtype_const_result

Modified: pypy/dist/pypy/rpython/rclass.py
==============================================================================
--- pypy/dist/pypy/rpython/rclass.py	(original)
+++ pypy/dist/pypy/rpython/rclass.py	Sat Sep 10 14:40:02 2005
@@ -202,6 +202,9 @@
         #
         return getclassrepr(self.rtyper, subclassdef).getvtable()
 
+    def get_ll_eq_function(self):
+        return None
+
     def getvtable(self, cast_to_typeptr=True):
         """Return a ptr to the vtable of this type."""
         if self.vtable is None:

Modified: pypy/dist/pypy/rpython/rdict.py
==============================================================================
--- pypy/dist/pypy/rpython/rdict.py	(original)
+++ pypy/dist/pypy/rpython/rdict.py	Sat Sep 10 14:40:02 2005
@@ -1,19 +1,26 @@
 from pypy.annotation.pairtype import pairtype
 from pypy.annotation import model as annmodel
 from pypy.objspace.flow.model import Constant
-from pypy.rpython import rmodel, lltype, rstr
+from pypy.rpython import rmodel, lltype
 from pypy.rpython.rarithmetic import r_uint
-from pypy.rpython import rlist, rconstantdict, remptydict
+from pypy.rpython import rlist
 from pypy.rpython import robject
 
 # ____________________________________________________________
 #
-#  pseudo implementation of RPython dictionary (this is per
-#  dictvalue type): 
+#  generic implementation of RPython dictionary, with parametric DICTKEY and
+#  DICTVALUE types.
+#
+#  XXX this should be re-optimized for specific types of keys; e.g.
+#      for string keys we don't need the two boolean flags but can use
+#      NULL and a special 'dummy' key as markers.  Similarly, for immutable dicts,
+#      the array should be inlined and num_pristine_entries is not needed.
 #
 #    struct dictentry {
-#        struct STR *key; 
-#        DICTVALUE value;  
+#        DICTKEY key;
+#        bool valid;      # to mark if the entry is filled
+#        bool everused;   # to mark if the entry is or has ever been filled
+#        DICTVALUE value;
 #    }
 #    
 #    struct dicttable {
@@ -26,66 +33,70 @@
 
 class __extend__(annmodel.SomeDict):
     def rtyper_makerepr(self, rtyper):
-        s_key = self.dictdef.dictkey.s_value
-        s_value = self.dictdef.dictvalue.s_value
-        if isinstance(s_key, annmodel.SomeString): 
-            if s_key.can_be_none():
-                raise rmodel.TyperError("cannot make repr of dict with "
-                                        "string-or-None keys")
-            dictvalue = self.dictdef.dictvalue 
-            return StrDictRepr(lambda: rtyper.getrepr(dictvalue.s_value), 
-                               dictvalue)
-        elif isinstance(s_key, (annmodel.SomeInteger,
-                                annmodel.SomeUnicodeCodePoint)):
-            dictkey = self.dictdef.dictkey
-            dictvalue = self.dictdef.dictvalue 
-            return rconstantdict.ConstantDictRepr(
-                        rtyper.getrepr(dictkey.s_value), 
-                        rtyper.getrepr(dictvalue.s_value))
-        elif isinstance(s_key, annmodel.SomeImpossibleValue):
-            return remptydict.emptydict_repr
-        elif (s_key.__class__ is annmodel.SomeObject and s_key.knowntype == object and
-              s_value.__class__ is annmodel.SomeObject and s_value.knowntype == object):
+        dictkey   = self.dictdef.dictkey
+        dictvalue = self.dictdef.dictvalue
+        s_key     = dictkey  .s_value
+        s_value   = dictvalue.s_value
+        if (s_key.__class__ is annmodel.SomeObject and s_key.knowntype == object and
+            s_value.__class__ is annmodel.SomeObject and s_value.knowntype == object):
             return robject.pyobj_repr
-        else: 
-            raise rmodel.TyperError("cannot make repr of %r" %(self.dictdef,))
+        else:
+            return DictRepr(lambda: rtyper.getrepr(s_key),
+                            lambda: rtyper.getrepr(s_value),
+                            dictkey,
+                            dictvalue)
 
     def rtyper_makekey(self):
         return (self.__class__, self.dictdef.dictkey, self.dictdef.dictvalue)
 
-class StrDictRepr(rmodel.Repr):
+class DictRepr(rmodel.Repr):
 
-    def __init__(self, value_repr, dictvalue=None): 
-        self.STRDICT = lltype.GcForwardReference()
-        self.lowleveltype = lltype.Ptr(self.STRDICT) 
+    def __init__(self, key_repr, value_repr, dictkey=None, dictvalue=None):
+        self.DICT = lltype.GcForwardReference()
+        self.lowleveltype = lltype.Ptr(self.DICT)
+        if not isinstance(key_repr, rmodel.Repr):  # not computed yet, done by setup()
+            assert callable(key_repr)
+            self._key_repr_computer = key_repr 
+        else:
+            self.key_repr = key_repr  
         if not isinstance(value_repr, rmodel.Repr):  # not computed yet, done by setup()
             assert callable(value_repr)
             self._value_repr_computer = value_repr 
         else:
             self.value_repr = value_repr  
+        self.dictkey = dictkey
         self.dictvalue = dictvalue
         self.dict_cache = {}
         # setup() needs to be called to finish this initialization
 
     def _setup_repr(self):
+        if 'key_repr' not in self.__dict__:
+            self.key_repr = self._key_repr_computer()
         if 'value_repr' not in self.__dict__:
             self.value_repr = self._value_repr_computer()
-        if isinstance(self.STRDICT, lltype.GcForwardReference):
+        if isinstance(self.DICT, lltype.GcForwardReference):
+            self.DICTKEY = self.key_repr.lowleveltype
             self.DICTVALUE = self.value_repr.lowleveltype
             self.DICTENTRY = lltype.Struct("dictentry", 
-                        ("key", lltype.Ptr(rstr.STR)), 
-                        ('value', self.DICTVALUE))
+                                ("key", self.DICTKEY),
+                                ("valid", lltype.Bool),
+                                ("everused", lltype.Bool),
+                                ("value", self.DICTVALUE))
             self.DICTENTRYARRAY = lltype.GcArray(self.DICTENTRY)
-            self.STRDICT.become(lltype.GcStruct("dicttable", 
+            self.DICT.become(lltype.GcStruct("dicttable",
                                 ("num_items", lltype.Signed), 
                                 ("num_pristine_entries", lltype.Signed), 
                                 ("entries", lltype.Ptr(self.DICTENTRYARRAY))))
+        if 'll_keyhash' not in self.__dict__:
+            # figure out which functions must be used to hash and compare keys
+            self.ll_keyeq   = self.key_repr.get_ll_eq_function()   # can be None
+            self.ll_keyhash = self.key_repr.get_ll_hash_function()
 
     def convert_const(self, dictobj):
         # get object from bound dict methods
         #dictobj = getattr(dictobj, '__self__', dictobj) 
         if dictobj is None:
-            return nullptr(self.STRDICT)
+            return nullptr(self.DICT)
         if not isinstance(dictobj, dict):
             raise TyperError("expected a dict: %r" % (dictobj,))
         try:
@@ -93,32 +104,33 @@
             return self.dict_cache[key]
         except KeyError:
             self.setup()
-            l_dict = ll_newstrdict(self.lowleveltype) 
+            l_dict = ll_newdict(self.lowleveltype)
             self.dict_cache[key] = l_dict 
-            r_key = rstr.string_repr 
+            r_key = self.key_repr
             r_value = self.value_repr
             for dictkey, dictvalue in dictobj.items():
                 llkey = r_key.convert_const(dictkey)
                 llvalue = r_value.convert_const(dictvalue)
-                ll_strdict_setitem(l_dict, llkey, llvalue)
-            return l_dict 
+                ll_dict_setitem(l_dict, llkey, llvalue, self)
+            return l_dict
 
     def rtype_len(self, hop):
         v_dict, = hop.inputargs(self)
-        return hop.gendirectcall(ll_strdict_len, v_dict)
+        return hop.gendirectcall(ll_dict_len, v_dict)
 
     def rtype_is_true(self, hop):
         v_dict, = hop.inputargs(self)
-        return hop.gendirectcall(ll_strdict_is_true, v_dict)
+        return hop.gendirectcall(ll_dict_is_true, v_dict)
 
     def make_iterator_repr(self, *variant):
-        return StrDictIteratorRepr(self, *variant)
+        return DictIteratorRepr(self, *variant)
 
     def rtype_method_get(self, hop):
-        v_dict, v_key, v_default = hop.inputargs(self, rstr.string_repr,
+        v_dict, v_key, v_default = hop.inputargs(self, self.key_repr,
                                                  self.value_repr)
+        crepr = hop.inputconst(lltype.Void, self)
         hop.exception_cannot_occur()
-        return hop.gendirectcall(ll_get, v_dict, v_key, v_default)
+        return hop.gendirectcall(ll_get, v_dict, v_key, v_default, crepr)
 
     def rtype_method_copy(self, hop):
         v_dict, = hop.inputargs(self)
@@ -127,8 +139,9 @@
 
     def rtype_method_update(self, hop):
         v_dic1, v_dic2 = hop.inputargs(self, self)
+        crepr = hop.inputconst(lltype.Void, self)
         hop.exception_cannot_occur()
-        return hop.gendirectcall(ll_update, v_dic1, v_dic2)
+        return hop.gendirectcall(ll_update, v_dic1, v_dic2, crepr)
 
     def _rtype_method_kvi(self, hop, spec):
         v_dic, = hop.inputargs(self)
@@ -149,69 +162,67 @@
 
     def rtype_method_iterkeys(self, hop):
         hop.exception_cannot_occur()
-        return StrDictIteratorRepr(self, "keys").newiter(hop)
+        return DictIteratorRepr(self, "keys").newiter(hop)
 
     def rtype_method_itervalues(self, hop):
         hop.exception_cannot_occur()
-        return StrDictIteratorRepr(self, "values").newiter(hop)
+        return DictIteratorRepr(self, "values").newiter(hop)
 
     def rtype_method_iteritems(self, hop):
         hop.exception_cannot_occur()
-        return StrDictIteratorRepr(self, "items").newiter(hop)
+        return DictIteratorRepr(self, "items").newiter(hop)
 
     def rtype_method_clear(self, hop):
         v_dict, = hop.inputargs(self)
         hop.exception_cannot_occur()
         return hop.gendirectcall(ll_clear, v_dict)
 
-class __extend__(pairtype(StrDictRepr, rmodel.StringRepr)): 
+class __extend__(pairtype(DictRepr, rmodel.Repr)): 
 
-    def rtype_getitem((r_dict, r_string), hop):
-        v_dict, v_key = hop.inputargs(r_dict, rstr.string_repr)
+    def rtype_getitem((r_dict, r_key), hop):
+        v_dict, v_key = hop.inputargs(r_dict, r_dict.key_repr)
+        crepr = hop.inputconst(lltype.Void, r_dict)
         hop.has_implicit_exception(KeyError)   # record that we know about it
         hop.exception_is_here()
-        return hop.gendirectcall(ll_strdict_getitem, v_dict, v_key)
+        return hop.gendirectcall(ll_dict_getitem, v_dict, v_key, crepr)
 
-    def rtype_delitem((r_dict, r_string), hop):
-        v_dict, v_key = hop.inputargs(r_dict, rstr.string_repr) 
+    def rtype_delitem((r_dict, r_key), hop):
+        v_dict, v_key = hop.inputargs(r_dict, r_dict.key_repr)
+        crepr = hop.inputconst(lltype.Void, r_dict)
         hop.has_implicit_exception(KeyError)   # record that we know about it
         hop.exception_is_here()
-        return hop.gendirectcall(ll_strdict_delitem, v_dict, v_key)
+        return hop.gendirectcall(ll_dict_delitem, v_dict, v_key, crepr)
 
-    def rtype_setitem((r_dict, r_string), hop):
-        v_dict, v_key, v_value = hop.inputargs(r_dict, rstr.string_repr, r_dict.value_repr) 
-        hop.gendirectcall(ll_strdict_setitem, v_dict, v_key, v_value)
-
-    def rtype_contains((r_dict, r_string), hop):
-        v_dict, v_key = hop.inputargs(r_dict, rstr.string_repr)
-        return hop.gendirectcall(ll_contains, v_dict, v_key)
+    def rtype_setitem((r_dict, r_key), hop):
+        v_dict, v_key, v_value = hop.inputargs(r_dict, r_dict.key_repr, r_dict.value_repr)
+        crepr = hop.inputconst(lltype.Void, r_dict)
+        hop.gendirectcall(ll_dict_setitem, v_dict, v_key, v_value, crepr)
+
+    def rtype_contains((r_dict, r_key), hop):
+        v_dict, v_key = hop.inputargs(r_dict, r_dict.key_repr)
+        crepr = hop.inputconst(lltype.Void, r_dict)
+        return hop.gendirectcall(ll_contains, v_dict, v_key, crepr)
         
-class __extend__(pairtype(StrDictRepr, StrDictRepr)):
+class __extend__(pairtype(DictRepr, DictRepr)):
     def convert_from_to((r_dict1, r_dict2), v, llops):
-        # check that we don't convert from StrDicts with
-        # different value types 
+        # check that we don't convert from Dicts with
+        # different key/value types 
+        if r_dict1.dictkey is None or r_dict2.dictkey is None:
+            return NotImplemented
+        if r_dict1.dictkey is not r_dict2.dictkey:
+            return NotImplemented
         if r_dict1.dictvalue is None or r_dict2.dictvalue is None:
             return NotImplemented
         if r_dict1.dictvalue is not r_dict2.dictvalue:
             return NotImplemented
         return v
 
-    #def rtype_add((self, _), hop):
-    #    v_lst1, v_lst2 = hop.inputargs(self, self)
-    #    return hop.gendirectcall(ll_concat, v_lst1, v_lst2)
-#
-#    def rtype_inplace_add((self, _), hop):
-#        v_lst1, v_lst2 = hop.inputargs(self, self)
-#        hop.gendirectcall(ll_extend, v_lst1, v_lst2)
-#        return v_lst1
-
 # ____________________________________________________________
 #
 #  Low-level methods.  These can be run for testing, but are meant to
 #  be direct_call'ed from rtyped flow graphs, which means that they will
 #  get flowed and annotated, mostly with SomePtr.
 
-deleted_entry_marker = lltype.malloc(rstr.STR, 0, immortal=True)
 def dum_keys(): pass
 def dum_values(): pass
 def dum_items():pass
@@ -219,114 +230,124 @@
                "values": dum_values,
                "items":  dum_items}
 
-def ll_strdict_len(d):
+def ll_dict_len(d):
     return d.num_items 
 
-def ll_strdict_is_true(d):
+def ll_dict_is_true(d):
     # check if a dict is True, allowing for None
     return bool(d) and d.num_items != 0
 
-def ll_strdict_getitem(d, key): 
-    entry = ll_strdict_lookup(d, key) 
-    if entry.key and entry.key != deleted_entry_marker: 
+def ll_dict_getitem(d, key, dictrepr):
+    entry = ll_dict_lookup(d, key, dictrepr)
+    if entry.valid:
         return entry.value 
     else: 
         raise KeyError 
 
-def ll_strdict_setitem(d, key, value): 
-    entry = ll_strdict_lookup(d, key)
-    if not entry.key: 
-        entry.key = key 
-        entry.value = value 
-        d.num_items += 1
+def ll_dict_setitem(d, key, value, dictrepr):
+    entry = ll_dict_lookup(d, key, dictrepr)
+    entry.value = value
+    if entry.valid:
+        return
+    entry.key = key 
+    entry.valid = True
+    d.num_items += 1
+    if not entry.everused:
+        entry.everused = True
         d.num_pristine_entries -= 1
         if d.num_pristine_entries <= len(d.entries) / 3:
-            ll_strdict_resize(d)
-    elif entry.key == deleted_entry_marker: 
-        entry.key = key 
-        entry.value = value 
-        d.num_items += 1
-    else:
-        entry.value = value 
+            ll_dict_resize(d, dictrepr)
 
-def ll_strdict_delitem(d, key): 
-    entry = ll_strdict_lookup(d, key)
-    if not entry.key or entry.key == deleted_entry_marker: 
-         raise KeyError
-    entry.key = deleted_entry_marker
+def ll_dict_delitem(d, key, dictrepr):
+    entry = ll_dict_lookup(d, key, dictrepr)
+    if not entry.valid:
+        raise KeyError
+    entry.valid = False
+    d.num_items -= 1
+    # clear the key and the value if they are pointers
+    keytype = lltype.typeOf(entry).TO.key
+    if isinstance(keytype, lltype.Ptr):
+        key = entry.key   # careful about destructor side effects
+        entry.key = lltype.nullptr(keytype.TO)
     valuetype = lltype.typeOf(entry).TO.value
     if isinstance(valuetype, lltype.Ptr):
         entry.value = lltype.nullptr(valuetype.TO)
-    d.num_items -= 1
     num_entries = len(d.entries)
-    if num_entries > STRDICT_INITSIZE and d.num_items < num_entries / 4: 
-        ll_strdict_resize(d) 
+    if num_entries > DICT_INITSIZE and d.num_items < num_entries / 4:
+        ll_dict_resize(d, dictrepr)
 
-def ll_strdict_resize(d):
+def ll_dict_resize(d, dictrepr):
     old_entries = d.entries
     old_size = len(old_entries) 
     # make a 'new_size' estimate and shrink it if there are many
     # deleted entry markers
     new_size = old_size * 2
-    while new_size > STRDICT_INITSIZE and d.num_items < new_size / 4:
+    while new_size > DICT_INITSIZE and d.num_items < new_size / 4:
         new_size /= 2
     d.entries = lltype.malloc(lltype.typeOf(old_entries).TO, new_size)
     d.num_pristine_entries = new_size - d.num_items
     i = 0
     while i < old_size:
         entry = old_entries[i]
-        if entry.key and entry.key != deleted_entry_marker:
-           new_entry = ll_strdict_lookup(d, entry.key)
+        if entry.valid:
+           new_entry = ll_dict_lookup(d, entry.key, dictrepr)
            new_entry.key = entry.key
            new_entry.value = entry.value
+           new_entry.valid = True
+           new_entry.everused = True
         i += 1
 
-# the below is a port of CPython's dictobject.c's lookdict implementation 
+# ------- a port of CPython's dictobject.c's lookdict implementation -------
 PERTURB_SHIFT = 5
 
-def ll_strdict_lookup(d, key): 
-    hash = rstr.ll_strhash(key) 
+def ll_dict_lookup(d, key, dictrepr):
+    hash = dictrepr.ll_keyhash(key)
     entries = d.entries
     mask = len(entries) - 1
     i = r_uint(hash & mask) 
 
+    """XXX MUTATION PROTECTION!"""
+
     # do the first try before any looping 
     entry = entries[i]
-    if not entry.key or entry.key == key: 
-        return entry 
-    if entry.key == deleted_entry_marker: 
-        freeslot = entry 
-    else: 
-        if entry.key.hash == hash and rstr.ll_streq(entry.key, key): 
-            return entry 
+    if entry.valid:
+        if entry.key == key:
+            return entry   # found the entry
+        if dictrepr.ll_keyeq is not None and dictrepr.ll_keyeq(entry.key, key):
+            return entry   # found the entry
         freeslot = lltype.nullptr(lltype.typeOf(entry).TO)
+    elif entry.everused:
+        freeslot = entry
+    else:
+        return entry    # pristine entry -- lookup failed
 
-    # In the loop, key == deleted_entry_marker is by far (factor of 100s) the
-    # least likely outcome, so test for that last.  
+    # In the loop, a deleted entry (everused and not valid) is by far
+    # (factor of 100s) the least likely outcome, so test for that last.
     perturb = r_uint(hash) 
     while 1: 
-        i = (i << 2) + i + perturb + 1
-        entry = entries[i & mask]
-        if not entry.key: 
+        i = ((i << 2) + i + perturb + 1) & mask
+        entry = entries[i]
+        if not entry.everused:
             return freeslot or entry 
-        if entry.key == key or (entry.key.hash == hash and 
-                                entry.key != deleted_entry_marker and
-                                rstr.ll_streq(entry.key, key)): 
-            return entry
-        if entry.key == deleted_entry_marker and not freeslot:
+        elif entry.valid:
+            if entry.key == key:
+                return entry
+            if dictrepr.ll_keyeq is not None and dictrepr.ll_keyeq(entry.key, key):
+                return entry
+        elif not freeslot:
             freeslot = entry 
         perturb >>= PERTURB_SHIFT
 
 # ____________________________________________________________
 #
 #  Irregular operations.
-STRDICT_INITSIZE = 8
+DICT_INITSIZE = 8
 
-def ll_newstrdict(DICTPTR):
+def ll_newdict(DICTPTR):
     d = lltype.malloc(DICTPTR.TO)
-    d.entries = lltype.malloc(DICTPTR.TO.entries.TO, STRDICT_INITSIZE)
+    d.entries = lltype.malloc(DICTPTR.TO.entries.TO, DICT_INITSIZE)
     d.num_items = 0  # but still be explicit
-    d.num_pristine_entries = STRDICT_INITSIZE 
+    d.num_pristine_entries = DICT_INITSIZE 
     return d
 
 def rtype_newdict(hop):
@@ -334,31 +355,27 @@
     if r_dict == robject.pyobj_repr: # special case: SomeObject: SomeObject dicts!
         cdict = hop.inputconst(robject.pyobj_repr, dict)
         return hop.genop('simple_call', [cdict], resulttype = robject.pyobj_repr)
-    if r_dict == remptydict.emptydict_repr: # other special case: empty dicts
-        return hop.inputconst(lltype.Void, {})
-    if not isinstance(r_dict, StrDictRepr):
-        raise rmodel.TyperError("cannot create non-StrDicts, got %r" %(r_dict,))
     c1 = hop.inputconst(lltype.Void, r_dict.lowleveltype)
-    v_result = hop.gendirectcall(ll_newstrdict, c1) 
+    v_result = hop.gendirectcall(ll_newdict, c1)
     return v_result
 
 # ____________________________________________________________
 #
 #  Iteration.
 
-class StrDictIteratorRepr(rmodel.IteratorRepr):
+class DictIteratorRepr(rmodel.IteratorRepr):
 
     def __init__(self, r_dict, variant="keys"):
         self.r_dict = r_dict
         self.variant = variant
-        self.lowleveltype = lltype.Ptr(lltype.GcStruct('strdictiter',
+        self.lowleveltype = lltype.Ptr(lltype.GcStruct('dictiter',
                                          ('dict', r_dict.lowleveltype),
                                          ('index', lltype.Signed)))
 
     def newiter(self, hop):
         v_dict, = hop.inputargs(self.r_dict)
         citerptr = hop.inputconst(lltype.Void, self.lowleveltype)
-        return hop.gendirectcall(ll_strdictiter, citerptr, v_dict)
+        return hop.gendirectcall(ll_dictiter, citerptr, v_dict)
 
     def rtype_next(self, hop):
         v_iter, = hop.inputargs(self)
@@ -367,31 +384,30 @@
         c1 = hop.inputconst(lltype.Void, r_list.lowleveltype)
         hop.has_implicit_exception(StopIteration) # record that we know about it
         hop.exception_is_here()
-        return hop.gendirectcall(ll_strdictnext, v_iter, v_func, c1)
+        return hop.gendirectcall(ll_dictnext, v_iter, v_func, c1)
 
-def ll_strdictiter(ITERPTR, d):
+def ll_dictiter(ITERPTR, d):
     iter = lltype.malloc(ITERPTR.TO)
     iter.dict = d
     iter.index = 0
     return iter
 
-def ll_strdictnext(iter, func, RETURNTYPE):
+def ll_dictnext(iter, func, RETURNTYPE):
     entries = iter.dict.entries
     index = iter.index
     entries_len = len(entries)
     while index < entries_len:
         entry = entries[index]
-        key = entry.key
         index = index + 1
-        if key and key != deleted_entry_marker:
+        if entry.valid:
             iter.index = index
             if func is dum_items:
                 r = lltype.malloc(RETURNTYPE.TO)
-                r.item0 = key
+                r.item0 = entry.key
                 r.item1 = entry.value
                 return r
             elif func is dum_keys:
-                return key
+                return entry.key
             elif func is dum_values:
                 return entry.value
     iter.index = index
@@ -400,9 +416,9 @@
 # _____________________________________________________________
 # methods
 
-def ll_get(dict, key, default):
-    entry = ll_strdict_lookup(dict, key) 
-    if entry.key and entry.key != deleted_entry_marker: 
+def ll_get(dict, key, default, dictrepr):
+    entry = ll_dict_lookup(dict, key, dictrepr) 
+    if entry.valid:
         return entry.value
     else: 
         return default
@@ -420,25 +436,28 @@
         entry = dict.entries[i]
         d_entry.key = entry.key
         d_entry.value = entry.value
+        d_entry.valid = entry.valid
+        d_entry.everused = entry.everused
         i += 1
     return d
 
 def ll_clear(d):
-    if len(d.entries) == d.num_pristine_entries == STRDICT_INITSIZE:
+    if len(d.entries) == d.num_pristine_entries == DICT_INITSIZE:
         return
     DICTPTR = lltype.typeOf(d)
-    d.entries = lltype.malloc(DICTPTR.TO.entries.TO, STRDICT_INITSIZE)
+    d.entries = lltype.malloc(DICTPTR.TO.entries.TO, DICT_INITSIZE)
     d.num_items = 0
-    d.num_pristine_entries = STRDICT_INITSIZE 
+    d.num_pristine_entries = DICT_INITSIZE
 
-def ll_update(dic1, dic2):
-    d2len =len(dic2.entries)
+def ll_update(dic1, dic2, dictrepr):
+    # XXX warning, no protection against ll_dict_setitem mutating dic2
+    d2len = len(dic2.entries)
     entries = dic2.entries
     i = 0
     while i < d2len:
         entry = entries[i]
-        if entry.key and entry.key != deleted_entry_marker:
-            ll_strdict_setitem(dic1, entry.key, entry.value)
+        if entry.valid:
+            ll_dict_setitem(dic1, entry.key, entry.value, dictrepr)
         i += 1
 
 # this is an implementation of keys(), values() and items()
@@ -455,23 +474,20 @@
     p = 0
     while i < dlen:
         entry = entries[i]
-        key = entry.key
-        if key and key != deleted_entry_marker:
+        if entry.valid:
             if func is dum_items:
                 r = lltype.malloc(LISTPTR.TO.items.TO.OF.TO)
-                r.item0 = key
+                r.item0 = entry.key
                 r.item1 = entry.value
                 items[p] = r
             elif func is dum_keys:
-                items[p] = key
+                items[p] = entry.key
             elif func is dum_values:
                 items[p] = entry.value
             p += 1
         i += 1
     return res
 
-def ll_contains(d, key): 
-    entry = ll_strdict_lookup(d, key) 
-    if entry.key and entry.key != deleted_entry_marker: 
-        return True
-    return False
+def ll_contains(d, key, dictrepr):
+    entry = ll_dict_lookup(d, key, dictrepr)
+    return entry.valid

Modified: pypy/dist/pypy/rpython/rint.py
==============================================================================
--- pypy/dist/pypy/rpython/rint.py	(original)
+++ pypy/dist/pypy/rpython/rint.py	Sat Sep 10 14:40:02 2005
@@ -209,6 +209,9 @@
     def get_ll_eq_function(self):
         return None 
 
+    def get_ll_hash_function(self):
+        return ll_hash_int
+
     def rtype_chr(_, hop):
         vlist =  hop.inputargs(Signed)
         return hop.genop('cast_int_to_char', vlist, resulttype=Char)
@@ -399,6 +402,9 @@
         j += 1
     return result
 
+def ll_hash_int(n):
+    return n
+
 #
 # _________________________ Conversions _________________________
 

Modified: pypy/dist/pypy/rpython/rmodel.py
==============================================================================
--- pypy/dist/pypy/rpython/rmodel.py	(original)
+++ pypy/dist/pypy/rpython/rmodel.py	Sat Sep 10 14:40:02 2005
@@ -3,7 +3,7 @@
 from pypy.objspace.flow.model import Constant
 from pypy.rpython.lltype import Void, Bool, Float, Signed, Char, UniChar
 from pypy.rpython.lltype import typeOf, LowLevelType, Ptr, PyObject
-from pypy.rpython.lltype import FuncType, functionptr
+from pypy.rpython.lltype import FuncType, functionptr, cast_ptr_to_int
 from pypy.tool.ansi_print import ansi_print
 from pypy.rpython.error import TyperError, MissingRTypeOperation 
 
@@ -101,6 +101,12 @@
     def get_ll_eq_function(self): 
         raise TyperError, 'no equality function for %r' % self
 
+    def get_ll_hash_function(self):
+        if not isinstance(self.lowleveltype, Ptr):
+            raise TyperError, 'no hashing function for %r' % self
+        # default behavior: use the pointer identity as a hash
+        return ll_hash_ptr
+
     def rtype_bltn_list(self, hop):
         raise TyperError, 'no list() support for %r' % self
 
@@ -152,6 +158,13 @@
     def make_iterator_repr(self, *variant):
         raise TyperError("%s is not iterable" % (self,))
 
+def ll_hash_ptr(p):
+    return cast_ptr_to_int(p)
+
+def ll_hash_void(v):
+    return 0
+
+
 class IteratorRepr(Repr):
     """Base class of Reprs of any kind of iterator."""
 
@@ -251,6 +264,8 @@
 
 class VoidRepr(Repr):
     lowleveltype = Void
+    def get_ll_eq_function(self): return None
+    def get_ll_hash_function(self): return ll_hash_void
 impossible_repr = VoidRepr()
 
 # ____________________________________________________________

Modified: pypy/dist/pypy/rpython/rstr.py
==============================================================================
--- pypy/dist/pypy/rpython/rstr.py	(original)
+++ pypy/dist/pypy/rpython/rstr.py	Sat Sep 10 14:40:02 2005
@@ -76,6 +76,9 @@
     def get_ll_eq_function(self):
         return ll_streq
 
+    def get_ll_hash_function(self):
+        return ll_strhash
+
     def rtype_len(_, hop):
         v_str, = hop.inputargs(string_repr)
         return hop.gendirectcall(ll_strlen, v_str)
@@ -393,6 +396,9 @@
     def get_ll_eq_function(self):
         return None 
 
+    def get_ll_hash_function(self):
+        return ll_char_hash
+
     def rtype_len(_, hop):
         return hop.inputconst(Signed, 1)
 
@@ -446,6 +452,9 @@
     def get_ll_eq_function(self):
         return None 
 
+    def get_ll_hash_function(self):
+        return ll_unichar_hash
+
 ##    def rtype_len(_, hop):
 ##        return hop.inputconst(Signed, 1)
 ##
@@ -536,6 +545,12 @@
         j += 1
     return newstr
 
+def ll_char_hash(ch):
+    return ord(ch)
+
+def ll_unichar_hash(ch):
+    return ord(ch)
+
 def ll_strlen(s):
     return len(s.chars)
 

Modified: pypy/dist/pypy/rpython/test/test_rdict.py
==============================================================================
--- pypy/dist/pypy/rpython/test/test_rdict.py	(original)
+++ pypy/dist/pypy/rpython/test/test_rdict.py	Sat Sep 10 14:40:02 2005
@@ -90,7 +90,7 @@
         return d[c2]
 
     char_by_hash = {}
-    base = rdict.STRDICT_INITSIZE
+    base = rdict.DICT_INITSIZE
     for y in range(0, 256):
         y = chr(y)
         y_hash = lowlevelhash(y) % base 
@@ -113,7 +113,7 @@
 
     res = interpret(func2, [ord(x), ord(y)])
     for i in range(len(res.entries)): 
-        assert res.entries[i].key != rdict.deleted_entry_marker
+        assert not (res.entries[i].everused and not res.entries[i].valid)
 
     def func3(c0, c1, c2, c3, c4, c5, c6, c7):
         d = {}
@@ -127,29 +127,29 @@
         c7 = chr(c7) ; d[c7] = 1; del d[c7]
         return d
 
-    if rdict.STRDICT_INITSIZE != 8: 
+    if rdict.DICT_INITSIZE != 8: 
         py.test.skip("make dict tests more indepdent from initsize")
     res = interpret(func3, [ord(char_by_hash[i][0]) 
-                               for i in range(rdict.STRDICT_INITSIZE)])
+                               for i in range(rdict.DICT_INITSIZE)])
     count_frees = 0
     for i in range(len(res.entries)):
-        if not res.entries[i].key:
+        if not res.entries[i].everused:
             count_frees += 1
     assert count_frees >= 3
 
 def test_dict_resize():
     def func(want_empty):
         d = {}
-        for i in range(rdict.STRDICT_INITSIZE):
+        for i in range(rdict.DICT_INITSIZE):
             d[chr(ord('a') + i)] = i
         if want_empty:
-            for i in range(rdict.STRDICT_INITSIZE):
+            for i in range(rdict.DICT_INITSIZE):
                 del d[chr(ord('a') + i)]
         return d
     res = interpret(func, [0])
-    assert len(res.entries) > rdict.STRDICT_INITSIZE 
+    assert len(res.entries) > rdict.DICT_INITSIZE 
     res = interpret(func, [1])
-    assert len(res.entries) == rdict.STRDICT_INITSIZE 
+    assert len(res.entries) == rdict.DICT_INITSIZE 
 
 def test_dict_iteration():
     def func(i, j):
@@ -300,3 +300,19 @@
     assert res is True
     res = interpret(func, [42])
     assert res is True
+
+def test_int_dict():
+    def func(a, b):
+        dic = {12: 34}
+        dic[a] = 1000
+        return dic.get(b, -123)
+    res = interpret(func, [12, 12])
+    assert res == 1000
+    res = interpret(func, [12, 13])
+    assert res == -123
+    res = interpret(func, [524, 12])
+    assert res == 34
+    res = interpret(func, [524, 524])
+    assert res == 1000
+    res = interpret(func, [524, 1036])
+    assert res == -123



More information about the Pypy-commit mailing list