[pypy-svn] r79685 - pypy/branch/jitypes2/pypy/module/pypyjit/test

antocuni at codespeak.net antocuni at codespeak.net
Tue Nov 30 15:56:24 CET 2010


Author: antocuni
Date: Tue Nov 30 15:56:22 2010
New Revision: 79685

Modified:
   pypy/branch/jitypes2/pypy/module/pypyjit/test/test_pypy_c.py
Log:
Write a test to check that a call through ctypes is fast, and also improve the machinery in test_pypy_c a bit.


Modified: pypy/branch/jitypes2/pypy/module/pypyjit/test/test_pypy_c.py
==============================================================================
--- pypy/branch/jitypes2/pypy/module/pypyjit/test/test_pypy_c.py	(original)
+++ pypy/branch/jitypes2/pypy/module/pypyjit/test/test_pypy_c.py	Tue Nov 30 15:56:22 2010
@@ -10,9 +10,9 @@
                     if op.getopname().startswith(prefix)]
 
     def __repr__(self):
-        return "%s%s" % (self.bytecode, list.__repr__(self))
+        return "%s%s" % (self.opcode, list.__repr__(self))
 
-ZERO_OP_BYTECODES = [
+ZERO_OP_OPCODES = [
     'POP_TOP',
     'ROT_TWO',
     'ROT_THREE',
@@ -82,11 +82,13 @@
     def run_source(self, source, expected_max_ops, *testcases, **kwds):
         assert isinstance(expected_max_ops, int)
         threshold = kwds.pop('threshold', 3)
+        filter_loops = kwds.pop('filter_loops', False) # keep only the loops beginning from case%d.py
         if kwds:
             raise TypeError, 'Unsupported keyword arguments: %s' % kwds.keys()
         source = py.code.Source(source)
         filepath = self.tmpdir.join('case%d.py' % self.counter)
         logfilepath = filepath.new(ext='.log')
+        self.logfilepath = logfilepath
         self.__class__.counter += 1
         f = filepath.open('w')
         print >> f, source
@@ -125,7 +127,7 @@
         if result.strip().startswith('SKIP:'):
             py.test.skip(result.strip())
         assert result.splitlines()[-1].strip() == 'OK :-)'
-        self.parse_loops(logfilepath)
+        self.parse_loops(logfilepath, filepath, filter_loops)
         self.print_loops()
         print logfilepath
         if self.total_ops > expected_max_ops:
@@ -133,7 +135,7 @@
                 self.total_ops, expected_max_ops)
         return result
 
-    def parse_loops(self, opslogfile):
+    def parse_loops(self, opslogfile, filepath, filter_loops):
         from pypy.jit.tool.oparser import parse
         from pypy.tool import logparser
         assert opslogfile.check()
@@ -143,27 +145,49 @@
                          if not from_entry_bridge(part, parts)]
         # skip entry bridges, they can contain random things
         self.loops = [parse(part, no_namespace=True) for part in self.rawloops]
-        self.sliced_loops = [] # contains all bytecodes of all loops
+        if filter_loops:
+            self.loops = self.filter_loops(filepath, self.loops)
+        self.all_bytecodes = []    # contains all bytecodes of all loops
+        self.bytecode_by_loop = {} # contains all bytecodes divided by loops
         self.total_ops = 0
         for loop in self.loops:
+            loop_bytecodes = []
+            self.bytecode_by_loop[loop] = loop_bytecodes
             self.total_ops += len(loop.operations)
             for op in loop.operations:
                 if op.getopname() == "debug_merge_point":
-                    sliced_loop = BytecodeTrace()
-                    sliced_loop.bytecode = op.getarg(0)._get_str().rsplit(" ", 1)[1]
-                    self.sliced_loops.append(sliced_loop)
+                    bytecode = BytecodeTrace()
+                    bytecode.opcode = op.getarg(0)._get_str().rsplit(" ", 1)[1]
+                    bytecode.debug_merge_point = op
+                    loop_bytecodes.append(bytecode)
+                    self.all_bytecodes.append(bytecode)
                 else:
-                    sliced_loop.append(op)
+                    bytecode.append(op)
         self.check_0_op_bytecodes()
 
+    def filter_loops(self, filepath, loops):
+        newloops = []
+        for loop in loops:
+            op = loop.operations[0]
+            # if the first op is not debug_merge_point, it's a bridge: for
+            # now, we always include them
+            if (op.getopname() != 'debug_merge_point' or 
+                str(filepath) in str(op.getarg(0))):
+                newloops.append(loop)
+        return newloops
+
     def check_0_op_bytecodes(self):
-        for bytecodetrace in self.sliced_loops:
-            if bytecodetrace.bytecode not in ZERO_OP_BYTECODES:
+        for bytecodetrace in self.all_bytecodes:
+            if bytecodetrace.opcode not in ZERO_OP_OPCODES:
                 continue
             assert not bytecodetrace
 
-    def get_by_bytecode(self, name):
-        return [ops for ops in self.sliced_loops if ops.bytecode == name]
+    def get_by_bytecode(self, name, loop=None):
+        if loop:
+            bytecodes = self.bytecode_by_loop[loop]
+        else:
+            bytecodes = self.all_bytecodes
+        return [ops for ops in bytecodes if ops.opcode == name]
 
     def print_loops(self):
         for rawloop in self.rawloops:
@@ -1232,6 +1256,54 @@
         assert call.getarg(1).value == 2.0
         assert call.getarg(2).value == 3.0
 
+    def test_ctypes_call(self):
+        from pypy.rlib.test.test_libffi import get_libm_name
+        libm_name = get_libm_name(sys.platform)
+        out = self.run_source('''
+        def main():
+            import ctypes
+            libm = ctypes.CDLL('%(libm_name)s')
+            fabs = libm.fabs
+            fabs.argtypes = [ctypes.c_double]
+            fabs.restype = ctypes.c_double
+            x = -4
+            for i in range(2000):
+                x = x + 0      # convince the perfect spec. to make x virtual
+                x = fabs(x)
+                x = x - 100
+            print fabs._ptr.getaddr()
+            return x
+        ''' % locals(),
+                              10000, ([], -4.0),
+                              threshold=1000,
+                              filter_loops=True)
+        fabs_addr = int(out.splitlines()[0])
+        assert len(self.loops) == 2 # the first is the loop, the second is a bridge
+        loop = self.loops[0]
+        call_functions = self.get_by_bytecode('CALL_FUNCTION', loop)
+        assert len(call_functions) == 2
+        #
+        # this is the call "fabs(x)"
+        call_main = call_functions[0]
+        assert 'code object main' in str(call_main.debug_merge_point)
+        assert call_main.get_opnames('call') == ['call'] # this is call(getexecutioncontext)
+        #
+        # this is the ffi call inside ctypes
+        call_ffi = call_functions[1]
+        last_ops = [op.getopname() for op in call_ffi[-6:]]
+        assert last_ops == ['force_token',
+                            'setfield_gc',         # framestackdepth
+                            'setfield_gc',         # vable_token
+                            'call_may_force',
+                            'guard_not_forced',
+                            'guard_no_exception']
+        call = call_ffi[-3]
+        assert call.getarg(0).value == fabs_addr
+        #
+        # finally, check that we don't force anything
+        for op in loop.operations:
+            assert op.getopname() != 'new_with_vtable'
+            
     # test_circular
 
 class AppTestJIT(PyPyCJITTests):



More information about the Pypy-commit mailing list