[pypy-svn] r39341 - in pypy/dist/pypy/rpython: . test

ludal at codespeak.net ludal at codespeak.net
Fri Feb 23 15:32:26 CET 2007


Author: ludal
Date: Fri Feb 23 15:32:25 2007
New Revision: 39341

Modified:
   pypy/dist/pypy/rpython/rfloat.py
   pypy/dist/pypy/rpython/rint.py
   pypy/dist/pypy/rpython/rtuple.py
   pypy/dist/pypy/rpython/test/test_rtuple.py
Log:
Provide ordering comparisons (<, <=, >, >=) for rtuples.
Works for mixes of ints and floats, and probably for every item type that
provides the necessary comparison operators.


Modified: pypy/dist/pypy/rpython/rfloat.py
==============================================================================
--- pypy/dist/pypy/rpython/rfloat.py	(original)
+++ pypy/dist/pypy/rpython/rfloat.py	Fri Feb 23 15:32:25 2007
@@ -110,6 +110,10 @@
 
     def get_ll_eq_function(self):
         return None
+    get_ll_gt_function = get_ll_eq_function
+    get_ll_lt_function = get_ll_eq_function
+    get_ll_ge_function = get_ll_eq_function
+    get_ll_le_function = get_ll_eq_function
 
     def get_ll_hash_function(self):
         return ll_hash_float

Modified: pypy/dist/pypy/rpython/rint.py
==============================================================================
--- pypy/dist/pypy/rpython/rint.py	(original)
+++ pypy/dist/pypy/rpython/rint.py	Fri Feb 23 15:32:25 2007
@@ -222,6 +222,13 @@
         raise TyperError("not an integer: %r" % (value,))
 
     def get_ll_eq_function(self):
+        return None
+    get_ll_gt_function = get_ll_eq_function
+    get_ll_lt_function = get_ll_eq_function
+    get_ll_ge_function = get_ll_eq_function
+    get_ll_le_function = get_ll_eq_function
+
+    def get_ll_ge_function(self):
         return None 
 
     def get_ll_hash_function(self):

Modified: pypy/dist/pypy/rpython/rtuple.py
==============================================================================
--- pypy/dist/pypy/rpython/rtuple.py	(original)
+++ pypy/dist/pypy/rpython/rtuple.py	Fri Feb 23 15:32:25 2007
@@ -22,6 +22,7 @@
 
 
 _gen_eq_function_cache = {}
+_gen_cmp_function_cache = {}
 _gen_hash_function_cache = {}
 _gen_str_function_cache = {}
 
@@ -46,6 +47,49 @@
 
         _gen_eq_function_cache[key] = ll_eq
         return ll_eq
+
+def gen_cmp_function(items_r, op_funcs, eq_funcs, strict):
+    """Generate a lexicographic tuple comparison ll function (<, <=, >, >=).
+    op_funcs[i] is the strict comparison (< or >) for item i, eq_funcs[i] its
+    equality; strict=True yields < or >, strict=False yields <= or >=.
+    """
+    cmp_funcs = list(zip(op_funcs, eq_funcs))
+    key = tuple(cmp_funcs), strict
+    try:
+        return _gen_cmp_function_cache[key]
+    except KeyError:
+        autounrolling_funclist = unrolling_iterable(enumerate(cmp_funcs))
+        def ll_cmp(t1, t2):
+            # The first pair of items that is not equal decides the result.
+            for i, (cmpfn, eqfn) in autounrolling_funclist:
+                attrname = 'item%d' % i
+                item1 = getattr(t1, attrname)
+                item2 = getattr(t2, attrname)
+                cmp_res = cmpfn(item1, item2)
+                if cmp_res:
+                    # strict comparison holds: t1 is ordered before/after t2
+                    return True
+                eq_res = eqfn(item1, item2)
+                if not eq_res:
+                    # items differ in the opposite direction
+                    return False
+            # All items are equal: only <= and >= (strict=False) hold
+            if strict:
+                return False
+            else:
+                return True
+        _gen_cmp_function_cache[key] = ll_cmp
+        return ll_cmp
+
+def gen_gt_function(items_r, strict):
+    gt_funcs = [r_item.get_ll_gt_function() or operator.gt for r_item in items_r]
+    eq_funcs = [r_item.get_ll_eq_function() or operator.eq for r_item in items_r]
+    return gen_cmp_function( items_r, gt_funcs, eq_funcs, strict )
+
+def gen_lt_function(items_r, strict):
+    lt_funcs = [r_item.get_ll_lt_function() or operator.lt for r_item in items_r]
+    eq_funcs = [r_item.get_ll_eq_function() or operator.eq for r_item in items_r]
+    return gen_cmp_function( items_r, lt_funcs, eq_funcs, strict )
 
 def gen_hash_function(items_r):
     # based on CPython
@@ -166,8 +210,20 @@
     def get_ll_eq_function(self):
         return gen_eq_function(self.items_r)
 
+    def get_ll_ge_function(self):
+        return gen_gt_function(self.items_r, False)
+
+    def get_ll_gt_function(self):
+        return gen_gt_function(self.items_r, True)
+
+    def get_ll_le_function(self):
+        return gen_lt_function(self.items_r, False)
+
+    def get_ll_lt_function(self):
+        return gen_lt_function(self.items_r, True)
+
     def get_ll_hash_function(self):
-        return gen_hash_function(self.items_r)    
+        return gen_hash_function(self.items_r)
 
     ll_str = property(gen_str_function)
 
@@ -241,6 +297,30 @@
         ll_eq = r_tup1.get_ll_eq_function()
         return hop.gendirectcall(ll_eq, v_tuple1, v_tuple2)
 
+    def rtype_ge((r_tup1, r_tup2), hop):
+        # XXX assumes that r_tup2 is convertible to r_tup1
+        v_tuple1, v_tuple2 = hop.inputargs(r_tup1, r_tup1)
+        ll_ge = r_tup1.get_ll_ge_function()
+        return hop.gendirectcall(ll_ge, v_tuple1, v_tuple2)
+
+    def rtype_gt((r_tup1, r_tup2), hop):
+        # XXX assumes that r_tup2 is convertible to r_tup1
+        v_tuple1, v_tuple2 = hop.inputargs(r_tup1, r_tup1)
+        ll_gt = r_tup1.get_ll_gt_function()
+        return hop.gendirectcall(ll_gt, v_tuple1, v_tuple2)
+
+    def rtype_le((r_tup1, r_tup2), hop):
+        # XXX assumes that r_tup2 is convertible to r_tup1
+        v_tuple1, v_tuple2 = hop.inputargs(r_tup1, r_tup1)
+        ll_le = r_tup1.get_ll_le_function()
+        return hop.gendirectcall(ll_le, v_tuple1, v_tuple2)
+
+    def rtype_lt((r_tup1, r_tup2), hop):
+        # XXX assumes that r_tup2 is convertible to r_tup1
+        v_tuple1, v_tuple2 = hop.inputargs(r_tup1, r_tup1)
+        ll_lt = r_tup1.get_ll_lt_function()
+        return hop.gendirectcall(ll_lt, v_tuple1, v_tuple2)
+
     def rtype_ne(tup1tup2, hop):
         v_res = tup1tup2.rtype_eq(hop)
         return hop.genop('bool_not', [v_res], resulttype=Bool)

Modified: pypy/dist/pypy/rpython/test/test_rtuple.py
==============================================================================
--- pypy/dist/pypy/rpython/test/test_rtuple.py	(original)
+++ pypy/dist/pypy/rpython/test/test_rtuple.py	Fri Feb 23 15:32:25 2007
@@ -280,6 +280,51 @@
         res = self.interpret(f, [2])
         assert res is True
 
+    TUPLES = [
+        ((1,2),  (2,3),   -1),
+        ((1,2),  (1,3),   -1),
+        ((1,2),  (1,1),    1),
+        ((1,2),  (1,2),    0),
+        ((1.,2.),(2.,3.), -1),
+        ((1.,2.),(1.,3.), -1),
+        ((1.,2.),(1.,1.),  1),
+        ((1.,2.),(1.,2.),  0),
+        ((1,2.),(2,3.), -1),
+        ((1,2.),(1,3.), -1),
+        ((1,2.),(1,1.),  1),
+        ((1,2.),(1,2.),  0),
+##         ((1,"def"),(1,"abc"), -1),
+##         ((1.,"abc"),(1.,"abc"), 0),
+        ]
+
+    def test_tuple_comparison(self):
+        def f_lt( a, b, c, d ):
+            return (a,b) < (c,d)
+        def f_le( a, b, c, d ):
+            return (a,b) <= (c,d)
+        def f_gt( a, b, c, d ):
+            return (a,b) > (c,d)
+        def f_ge( a, b, c, d ):
+            return (a,b) >= (c,d)
+        def test_lt( a,b,c,d,resu ):
+            res = self.interpret(f_lt,[a,b,c,d])
+            assert res == (resu == -1), "Error (%s,%s)<(%s,%s) is %s(%s)" % (a,b,c,d,res,resu)
+        def test_le( a,b,c,d,resu ):
+            res = self.interpret(f_le,[a,b,c,d])
+            assert res == (resu <= 0), "Error (%s,%s)<=(%s,%s) is %s(%s)" % (a,b,c,d,res,resu)
+        def test_gt( a,b,c,d,resu ):
+            res = self.interpret(f_gt,[a,b,c,d])
+            assert res == ( resu == 1 ), "Error (%s,%s)>(%s,%s) is %s(%s)" % (a,b,c,d,res,resu)
+        def test_ge( a,b,c,d,resu ):
+            res = self.interpret(f_ge,[a,b,c,d])
+            assert res == ( resu >= 0 ), "Error (%s,%s)>=(%s,%s) is %s(%s)" % (a,b,c,d,res,resu)
+
+        for (a,b),(c,d),resu in self.TUPLES:
+            yield test_lt, a,b,c,d, resu
+            yield test_gt, a,b,c,d, resu
+            yield test_le, a,b,c,d, resu
+            yield test_ge, a,b,c,d, resu
+
     def test_tuple_hash(self):
         def f(n):
             return hash((n, 6)) == hash((3, n*2))



More information about the Pypy-commit mailing list