[pypy-svn] r48901 - in pypy/branch/new-genc-tests-wrapper/pypy/translator: c c/test llsupport

rxe at codespeak.net rxe at codespeak.net
Wed Nov 21 14:53:03 CET 2007


Author: rxe
Date: Wed Nov 21 14:53:02 2007
New Revision: 48901

Modified:
   pypy/branch/new-genc-tests-wrapper/pypy/translator/c/genc.py
   pypy/branch/new-genc-tests-wrapper/pypy/translator/c/test/test_modwrapper.py
   pypy/branch/new-genc-tests-wrapper/pypy/translator/llsupport/modwrapper.py
Log:
(pedronis, rxe) Add a way to augment the entrypoint for module wrapping using a mixed-level annotator. This approach should also be extended to return values.

Modified: pypy/branch/new-genc-tests-wrapper/pypy/translator/c/genc.py
==============================================================================
--- pypy/branch/new-genc-tests-wrapper/pypy/translator/c/genc.py	(original)
+++ pypy/branch/new-genc-tests-wrapper/pypy/translator/c/genc.py	Wed Nov 21 14:53:02 2007
@@ -36,6 +36,9 @@
             libraries = []
         self.libraries = libraries
 
+    def augment_entrypoint(self):
+        """Hook run by build_database() before creating the database; no-op."""
+
     def build_database(self, pyobj_options=None):
         translator = self.translator
 
@@ -53,6 +56,8 @@
         else:
             stacklesstransformer = None
 
+        self.augment_entrypoint()  # may replace self.entrypoint before the database is built
+
         db = LowLevelDatabase(translator, standalone=self.standalone,
                               gcpolicyclass=gcpolicyclass,
                               stacklesstransformer=stacklesstransformer,
@@ -108,7 +113,7 @@
 
         if db is None:
             db = self.build_database()
-            self.db = db
+        self.db = db  # remember the database even when the caller supplied db
         pf = self.getentrypointptr()
         pfname = db.get(pf)
 
@@ -171,6 +176,13 @@
         self.graph_entrypoint = bk.getdesc(self.entrypoint).getuniquegraph()
         return getfunctionptr(self.graph_entrypoint)
 
+    def augment_entrypoint(self):
+        """Swap self.entrypoint for the wrapper produced by
+        modwrapper.augment_entrypoint() (str arguments arrive as char *)."""
+        from pypy.translator.llsupport import modwrapper
+        self.entrypoint = modwrapper.augment_entrypoint(self.translator,
+                                                        self.entrypoint)
+
     def compile(self):
         assert self.c_source_filename 
         assert not self._compiled

Modified: pypy/branch/new-genc-tests-wrapper/pypy/translator/c/test/test_modwrapper.py
==============================================================================
--- pypy/branch/new-genc-tests-wrapper/pypy/translator/c/test/test_modwrapper.py	(original)
+++ pypy/branch/new-genc-tests-wrapper/pypy/translator/c/test/test_modwrapper.py	Wed Nov 21 14:53:02 2007
@@ -81,6 +81,13 @@
         assert fn(2, 1) == 44
         assert fn(6, 2) == 90
 
+    def test_argument_string(self):
+        def fn(s):
+            return len(s)
+        fn = self.getcompiled(fn, [str])
+        assert fn('aaa') == 3
+        assert fn('') == 0
+
 
 
 class TestWrapperRefcounting(CompilationTestCase, WrapperTests):

Modified: pypy/branch/new-genc-tests-wrapper/pypy/translator/llsupport/modwrapper.py
==============================================================================
--- pypy/branch/new-genc-tests-wrapper/pypy/translator/llsupport/modwrapper.py	(original)
+++ pypy/branch/new-genc-tests-wrapper/pypy/translator/llsupport/modwrapper.py	Wed Nov 21 14:53:02 2007
@@ -1,11 +1,15 @@
 " THIS IS ONLY FOR TESTING "
 
 import py
+import inspect
 import ctypes
 from pypy.rpython.lltypesystem import lltype 
 from pypy.rpython.lltypesystem import llmemory
 from pypy.rlib.rarithmetic import r_uint, r_longlong, r_ulonglong
 from pypy.rpython.lltypesystem.rstr import STR
+from pypy.rpython.annlowlevel import MixLevelHelperAnnotator
+from pypy.annotation import model as annmodel
+from pypy.rpython.lltypesystem.rffi import charp2str, CCHARP
 
 class CtypesModule:
     """ use ctypes to create a temporary module """
@@ -58,19 +62,7 @@
     return res
 
 def from_str(arg):
-    class Chars(ctypes.Structure):
-        _fields_ = [("size", ctypes.c_int),
-                    ("data", ctypes.c_byte * len(arg))]
-    class STR(ctypes.Structure):
-        _fields_ = get_gc_header() + [
-                    ("hash", ctypes.c_int),
-                    ("chars", Chars)]
-    s = STR()
-    s.hash = 0
-    s.chars.size = len(arg)
-    for ii in range(len(arg)):
-        s.chars.data[ii] = ord(arg[ii])
-    return ctypes.addressof(s)
+    return arg  # no RPython STR is built any more; string args are now passed as CCHARP
 
 def to_r_uint(res):
     return {'type':'r_uint', 'value':long(res)}
@@ -123,7 +115,7 @@
 def array_to_list(res, C_TYPE, action, size=-1):
     if res:
         if size == -1:
-            size = ctypes.cast(res, ctypes.POINTER(ctypes.c_int)).contents.value
+            size = ctypes.cast(res + gc_header_offset(), ctypes.POINTER(ctypes.c_int)).contents.value  # size field follows the GC header
         class Array(ctypes.Structure):
             _fields_ = get_gc_header() + [
                         ("size", ctypes.c_int),
@@ -139,6 +131,12 @@
     size = ctypes.cast(addr_str, ctypes.POINTER(ctypes.c_int)).contents.value - 1
     name = ctypes.string_at(addr_str+4, size)
     return name
+
+def gc_header_offset():  # size in bytes of the fields returned by get_gc_header()
+    class GcHeader(ctypes.Structure):
+        _fields_ = get_gc_header()
+    return ctypes.sizeof(GcHeader)
+
"""

     epilog = """
@@ -228,8 +226,9 @@
             elif A is lltype.Float:
                 action = 'ctypes.c_double'
 
-            elif isinstance(A, lltype.Ptr) and A.TO is STR:
-                action = 'from_str'
+            elif isinstance(A, lltype.Ptr) and A == CCHARP:
+                action = 'identity'  # bypasses the TO_CTYPES assert of the else branch
+
             else:
                 assert A in self.TO_CTYPES
                 action = 'identity'
@@ -339,3 +338,43 @@
             return self[i]
         else:
             raise AttributeError, name
+
+def augment_entrypoint(translator, entrypoint):
+    """Wrap an already-annotated entrypoint for use by the module wrapper.
+
+    Every argument annotated as SomeString is accepted by the wrapper as a
+    CCHARP instead and converted back with charp2str() before entrypoint is
+    called.  The wrapper is annotated through a MixLevelHelperAnnotator,
+    reusing the entrypoint's result annotation, and then returned.
+    """
+    bk = translator.annotator.bookkeeper
+    graph_entrypoint = bk.getdesc(entrypoint).getuniquegraph()
+    args, varargs, kwds, _ = inspect.getargspec(entrypoint)
+    assert varargs is None and kwds is None
+    args_s = [translator.annotator.binding(v) for v in graph_entrypoint.getargs()]
+    assert len(args) == len(args_s)
+    s_result = translator.annotator.binding(graph_entrypoint.getreturnvar())
+    converted = []
+    converted_s = []
+    for v, s_v in zip(args, args_s):
+        if isinstance(s_v, annmodel.SomeString):
+            v = "__charp2str(%s)" % v
+            s_v = annmodel.SomePtr(CCHARP)
+        converted.append(v)
+        converted_s.append(s_v)
+    # The helpers live under '__'-prefixed names so that an entrypoint
+    # argument called 'entrypoint' or 'charp2str' cannot shadow them.
+    code = """
+def new_entrypoint(%s):
+    return __entrypoint(%s)
+""" % (", ".join(args), ", ".join(converted))
+    d = {'__entrypoint': entrypoint, '__charp2str': charp2str}
+    exec code in d
+
+    mixlevelannotator = MixLevelHelperAnnotator(translator.rtyper)
+    new_entrypoint = d['new_entrypoint']
+
+    mixlevelannotator.getgraph(new_entrypoint, converted_s, s_result)
+    mixlevelannotator.finish()
+
+    return new_entrypoint



More information about the Pypy-commit mailing list