[pypy-svn] r75261 - in pypy/branch/multijit-3/pypy/jit/codewriter: . test
arigo at codespeak.net
arigo at codespeak.net
Fri Jun 11 11:52:44 CEST 2010
Author: arigo
Date: Fri Jun 11 11:52:42 2010
New Revision: 75261
Modified:
pypy/branch/multijit-3/pypy/jit/codewriter/call.py
pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py
pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py
pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py
pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py
pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py
Log:
Fix the codewriter so that it performs the transformation
for multiple JitDrivers in a single pass.
Modified: pypy/branch/multijit-3/pypy/jit/codewriter/call.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/call.py (original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/call.py Fri Jun 11 11:52:42 2010
@@ -16,21 +16,22 @@
class CallControl(object):
virtualref_info = None # optionally set from outside
- virtualizable_info = None # optionally set from outside
- portal_runner_ptr = None # optionally set from outside
- def __init__(self, cpu=None, portal_graph=None):
+ def __init__(self, cpu=None, jitdrivers_sd=[]):
+ assert isinstance(jitdrivers_sd, list) # debugging
self.cpu = cpu
- self.portal_graph = portal_graph
+ self.jitdrivers_sd = jitdrivers_sd
self.jitcodes = {} # map {graph: jitcode}
self.unfinished_graphs = [] # list of graphs with pending jitcodes
- self.jitdriver = None
if hasattr(cpu, 'rtyper'): # for tests
self.rtyper = cpu.rtyper
translator = self.rtyper.annotator.translator
self.raise_analyzer = RaiseAnalyzer(translator)
self.readwrite_analyzer = ReadWriteAnalyzer(translator)
self.virtualizable_analyzer = VirtualizableAnalyzer(translator)
+ #
+ for index, jd in enumerate(jitdrivers_sd):
+ jd.index = index
def find_all_graphs(self, policy):
try:
@@ -41,8 +42,8 @@
def is_candidate(graph):
return policy.look_inside_graph(graph)
- assert self.portal_graph is not None
- todo = [self.portal_graph]
+ assert len(self.jitdrivers_sd) > 0
+ todo = [jd.portal_graph for jd in self.jitdrivers_sd]
if hasattr(self, 'rtyper'):
for oopspec_name, ll_args, ll_res in support.inline_calls_to:
c_func, _ = support.builtin_func_for_spec(self.rtyper,
@@ -122,7 +123,7 @@
def guess_call_kind(self, op, is_candidate=None):
if op.opname == 'direct_call':
funcptr = op.args[0].value
- if funcptr is self.portal_runner_ptr:
+ if self.jitdriver_sd_from_portal_runner_ptr(funcptr) is not None:
return 'recursive'
funcobj = get_funcobj(funcptr)
if getattr(funcobj, 'graph', None) is None:
@@ -143,6 +144,10 @@
# used only after find_all_graphs()
return graph in self.candidate_graphs
+ def grab_initial_jitcodes(self):
+ for jd in self.jitdrivers_sd:
+ jd.mainjitcode = self.get_jitcode(jd.portal_graph)
+
def enum_pending_graphs(self):
while self.unfinished_graphs:
graph = self.unfinished_graphs.pop()
@@ -241,12 +246,26 @@
return (effectinfo is None or
effectinfo.extraeffect >= EffectInfo.EF_CAN_RAISE)
- def found_jitdriver(self, jitdriver):
- if self.jitdriver is None:
- self.jitdriver = jitdriver
- else:
- assert self.jitdriver is jitdriver
+ def jitdriver_sd_from_portal_graph(self, graph):
+ for jd in self.jitdrivers_sd:
+ if jd.portal_graph is graph:
+ return jd
+ return None
- def getjitdriver(self):
- assert self.jitdriver is not None, "order dependency issue?"
- return self.jitdriver
+ def jitdriver_sd_from_portal_runner_ptr(self, funcptr):
+ for jd in self.jitdrivers_sd:
+ if funcptr is jd.portal_runner_ptr:
+ return jd
+ return None
+
+ def get_vinfo(self, VTYPEPTR):
+ seen = set()
+ for jd in self.jitdrivers_sd:
+ if jd.virtualizable_info is not None:
+ if jd.virtualizable_info.is_vtypeptr(VTYPEPTR):
+ seen.add(jd.virtualizable_info)
+ if seen:
+ assert len(seen) == 1
+ return seen.pop()
+ else:
+ return None
Modified: pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py (original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py Fri Jun 11 11:52:42 2010
@@ -14,29 +14,30 @@
class CodeWriter(object):
callcontrol = None # for tests
- def __init__(self, cpu=None, maingraph=None):
+ def __init__(self, cpu=None, jitdrivers_sd=[]):
self.cpu = cpu
self.assembler = Assembler()
- self.portal_graph = maingraph
- self.callcontrol = CallControl(cpu, maingraph)
+ self.callcontrol = CallControl(cpu, jitdrivers_sd)
+ self._seen_files = set()
def transform_func_to_jitcode(self, func, values, type_system='lltype'):
"""For testing."""
rtyper = support.annotate(func, values, type_system=type_system)
graph = rtyper.annotator.translator.graphs[0]
jitcode = JitCode("test")
- self.transform_graph_to_jitcode(graph, jitcode, True, True)
+ self.transform_graph_to_jitcode(graph, jitcode, True)
return jitcode
- def transform_graph_to_jitcode(self, graph, jitcode, portal, verbose):
+ def transform_graph_to_jitcode(self, graph, jitcode, verbose):
"""Transform a graph into a JitCode containing the same bytecode
in a different format.
"""
+ portal_jd = self.callcontrol.jitdriver_sd_from_portal_graph(graph)
graph = copygraph(graph, shallowvars=True)
#
# step 1: mangle the graph so that it contains the final instructions
# that we want in the JitCode, but still as a control flow graph
- transform_graph(graph, self.cpu, self.callcontrol, portal)
+ transform_graph(graph, self.cpu, self.callcontrol)
#
# step 2: perform register allocation on it
regallocs = {}
@@ -59,16 +60,14 @@
self.assembler.assemble(ssarepr, jitcode)
#
# print the resulting assembler
- self.print_ssa_repr(ssarepr, portal, verbose)
+ self.print_ssa_repr(ssarepr, portal_jd, verbose)
def make_jitcodes(self, verbose=False):
log.info("making JitCodes...")
- maingraph = self.portal_graph
- self.mainjitcode = self.callcontrol.get_jitcode(maingraph)
+ self.callcontrol.grab_initial_jitcodes()
count = 0
for graph, jitcode in self.callcontrol.enum_pending_graphs():
- self.transform_graph_to_jitcode(graph, jitcode,
- graph is maingraph, verbose)
+ self.transform_graph_to_jitcode(graph, jitcode, verbose)
count += 1
if not count % 500:
log.info("Produced %d jitcodes" % count)
@@ -76,33 +75,35 @@
log.info("there are %d JitCode instances." % count)
def setup_vrefinfo(self, vrefinfo):
+ # must be called at most once
+ assert self.callcontrol.virtualref_info is None
self.callcontrol.virtualref_info = vrefinfo
- def setup_virtualizable_info(self, vinfo):
- self.callcontrol.virtualizable_info = vinfo
-
- def setup_portal_runner_ptr(self, portal_runner_ptr):
- self.callcontrol.portal_runner_ptr = portal_runner_ptr
+ def setup_jitdriver(self, jitdriver_sd):
+ # Must be called once per jitdriver. Usually jitdriver_sd is an
+ # instance of pypy.jit.metainterp.jitdriver.JitDriverStaticData.
+ self.callcontrol.jitdrivers_sd.append(jitdriver_sd)
def find_all_graphs(self, policy):
return self.callcontrol.find_all_graphs(policy)
- def print_ssa_repr(self, ssarepr, portal, verbose):
+ def print_ssa_repr(self, ssarepr, portal_jitdriver, verbose):
if verbose:
print '%s:' % (ssarepr.name,)
print format_assembler(ssarepr)
else:
dir = udir.ensure("jitcodes", dir=1)
- if portal:
- name = "00_portal_runner"
+ if portal_jitdriver:
+ name = "%02d_portal_runner" % (portal_jitdriver.index,)
elif ssarepr.name and ssarepr.name != '?':
name = ssarepr.name
else:
name = 'unnamed' % id(ssarepr)
i = 1
extra = ''
- while dir.join(name+extra).check(exists=1):
+ while name+extra in self._seen_files:
i += 1
extra = '.%d' % i
+ self._seen_files.add(name+extra)
dir.join(name+extra).write(format_assembler(ssarepr))
log.dot()
Modified: pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py (original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py Fri Jun 11 11:52:42 2010
@@ -13,12 +13,12 @@
from pypy.translator.simplify import get_funcobj
-def transform_graph(graph, cpu=None, callcontrol=None, portal=True):
+def transform_graph(graph, cpu=None, callcontrol=None):
"""Transform a control flow graph to make it suitable for
being flattened in a JitCode.
"""
t = Transformer(cpu, callcontrol)
- t.transform(graph, portal)
+ t.transform(graph)
class Transformer(object):
@@ -27,9 +27,8 @@
self.cpu = cpu
self.callcontrol = callcontrol
- def transform(self, graph, portal):
+ def transform(self, graph):
self.graph = graph
- self.portal = portal
for block in list(graph.iterblocks()):
self.optimize_block(block)
@@ -317,10 +316,12 @@
return op1
def handle_recursive_call(self, op):
- ops = self.promote_greens(op.args[1:])
- targetgraph = self.callcontrol.portal_graph
- num_green_args = len(self.callcontrol.getjitdriver().greens)
- args = (self.make_three_lists(op.args[1:1+num_green_args]) +
+ jitdriver_sd = self.callcontrol.jitdriver_sd_from_portal_runner_ptr(
+ op.args[0])
+ ops = self.promote_greens(op.args[1:], jitdriver_sd.jitdriver)
+ num_green_args = len(jitdriver_sd.jitdriver.greens)
+ args = ([Constant(jitdriver_sd.index, lltype.Signed)] +
+ self.make_three_lists(op.args[1:1+num_green_args]) +
self.make_three_lists(op.args[1+num_green_args:]))
kind = getkind(op.result.concretetype)[0]
op0 = SpaceOperation('recursive_call_%s' % kind, args, op.result)
@@ -475,14 +476,14 @@
# check for virtualizable
try:
if self.is_virtualizable_getset(op):
- descr = self.get_virtualizable_field_descr(op.args[1].value)
+ descr = self.get_virtualizable_field_descr(op)
kind = getkind(RESULT)[0]
return [SpaceOperation('-live-', [], None),
SpaceOperation('getfield_vable_%s' % kind,
[v_inst, descr], op.result)]
- except VirtualizableArrayField:
+ except VirtualizableArrayField, e:
# xxx hack hack hack
- vinfo = self.callcontrol.virtualizable_info
+ vinfo = e.args[1]
arrayindex = vinfo.array_field_counter[op.args[1].value]
arrayfielddescr = vinfo.array_field_descrs[arrayindex]
arraydescr = vinfo.array_descrs[arrayindex]
@@ -519,7 +520,7 @@
return
# check for virtualizable
if self.is_virtualizable_getset(op):
- descr = self.get_virtualizable_field_descr(op.args[1].value)
+ descr = self.get_virtualizable_field_descr(op)
kind = getkind(RESULT)[0]
return [SpaceOperation('-live-', [], None),
SpaceOperation('setfield_vable_%s' % kind,
@@ -536,21 +537,23 @@
return (op.args[1].value == 'typeptr' and
op.args[0].concretetype.TO._hints.get('typeptr'))
+ def get_vinfo(self, v_virtualizable):
+ if self.callcontrol is None: # for tests
+ return None
+ return self.callcontrol.get_vinfo(v_virtualizable.concretetype)
+
def is_virtualizable_getset(self, op):
# every access of an object of exactly the type VTYPEPTR is
# likely to be a virtualizable access, but we still have to
# check it in pyjitpl.py.
- try:
- vinfo = self.callcontrol.virtualizable_info
- except AttributeError:
- return False
- if vinfo is None or not vinfo.is_vtypeptr(op.args[0].concretetype):
+ vinfo = self.get_vinfo(op.args[0])
+ if vinfo is None:
return False
res = False
if op.args[1].value in vinfo.static_field_to_extra_box:
res = True
if op.args[1].value in vinfo.array_fields:
- res = VirtualizableArrayField(self.graph)
+ res = VirtualizableArrayField(self.graph, vinfo)
if res:
flags = self.vable_flags[op.args[0]]
@@ -560,8 +563,9 @@
raise res
return res
- def get_virtualizable_field_descr(self, fieldname):
- vinfo = self.callcontrol.virtualizable_info
+ def get_virtualizable_field_descr(self, op):
+ fieldname = op.args[1].value
+ vinfo = self.get_vinfo(op.args[0])
index = vinfo.static_field_to_extra_box[fieldname]
return vinfo.static_field_descrs[index]
@@ -751,9 +755,10 @@
return Constant(value, lltype.Bool)
return op
- def promote_greens(self, args):
+ def promote_greens(self, args, jitdriver):
ops = []
- num_green_args = len(self.callcontrol.getjitdriver().greens)
+ num_green_args = len(jitdriver.greens)
+ assert len(args) == num_green_args + len(jitdriver.reds)
for v in args[:num_green_args]:
if isinstance(v, Variable) and v.concretetype is not lltype.Void:
kind = getkind(v.concretetype)
@@ -763,20 +768,19 @@
return ops
def rewrite_op_jit_marker(self, op):
- self.callcontrol.found_jitdriver(op.args[1].value)
key = op.args[0].value
- return getattr(self, 'handle_jit_marker__%s' % key)(op)
+ jitdriver = op.args[1].value
+ return getattr(self, 'handle_jit_marker__%s' % key)(op, jitdriver)
- def handle_jit_marker__jit_merge_point(self, op):
- assert self.portal, "jit_merge_point in non-main graph!"
- ops = self.promote_greens(op.args[2:])
- num_green_args = len(self.callcontrol.getjitdriver().greens)
+ def handle_jit_marker__jit_merge_point(self, op, jitdriver):
+ ops = self.promote_greens(op.args[2:], jitdriver)
+ num_green_args = len(jitdriver.greens)
args = (self.make_three_lists(op.args[2:2+num_green_args]) +
self.make_three_lists(op.args[2+num_green_args:]))
op1 = SpaceOperation('jit_merge_point', args, None)
return ops + [op1]
- def handle_jit_marker__can_enter_jit(self, op):
+ def handle_jit_marker__can_enter_jit(self, op, jitdriver):
return SpaceOperation('can_enter_jit', [], None)
def rewrite_op_debug_assert(self, op):
@@ -975,9 +979,8 @@
def rewrite_op_jit_force_virtualizable(self, op):
# this one is for virtualizables
- vinfo = self.callcontrol.virtualizable_info
+ vinfo = self.get_vinfo(op.args[0])
assert vinfo is not None
- assert vinfo.is_vtypeptr(op.args[0].concretetype)
self.vable_flags[op.args[0]] = op.args[2].value
return []
Modified: pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py (original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py Fri Jun 11 11:52:42 2010
@@ -52,13 +52,19 @@
# ____________________________________________________________
+class FakeJitDriverSD:
+ def __init__(self, portal_graph):
+ self.portal_graph = portal_graph
+ self.portal_runner_ptr = "???"
+
def test_find_all_graphs():
def g(x):
return x + 2
def f(x):
return g(x) + 1
rtyper = support.annotate(f, [7])
- cc = CallControl(portal_graph=rtyper.annotator.translator.graphs[0])
+ jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+ cc = CallControl(jitdrivers_sd=[jitdriver_sd])
res = cc.find_all_graphs(FakePolicy())
funcs = set([graph.func for graph in res])
assert funcs == set([f, g])
@@ -69,7 +75,8 @@
def f(x):
return g(x) + 1
rtyper = support.annotate(f, [7])
- cc = CallControl(portal_graph=rtyper.annotator.translator.graphs[0])
+ jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+ cc = CallControl(jitdrivers_sd=[jitdriver_sd])
class CustomFakePolicy:
def look_inside_graph(self, graph):
assert graph.name == 'g'
@@ -83,10 +90,11 @@
def test_guess_call_kind_and_calls_from_graphs():
class portal_runner_obj:
graph = object()
+ class FakeJitDriverSD:
+ portal_runner_ptr = portal_runner_obj
g = object()
g1 = object()
- cc = CallControl()
- cc.portal_runner_ptr = portal_runner_obj
+ cc = CallControl(jitdrivers_sd=[FakeJitDriverSD()])
cc.candidate_graphs = [g, g1]
op = SpaceOperation('direct_call', [Constant(portal_runner_obj)],
Modified: pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py (original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py Fri Jun 11 11:52:42 2010
@@ -35,6 +35,12 @@
def look_inside_graph(self, graph):
return graph.name != 'dont_look'
+class FakeJitDriverSD:
+ def __init__(self, portal_graph):
+ self.portal_graph = portal_graph
+ self.portal_runner_ptr = "???"
+ self.virtualizable_info = None
+
def test_loop():
def f(a, b):
@@ -70,11 +76,11 @@
def fff(a, b):
return ggg(b) - ggg(a)
rtyper = support.annotate(fff, [35, 42])
- maingraph = rtyper.annotator.translator.graphs[0]
- cw = CodeWriter(FakeCPU(rtyper), maingraph)
+ jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+ cw = CodeWriter(FakeCPU(rtyper), [jitdriver_sd])
cw.find_all_graphs(FakePolicy())
cw.make_jitcodes(verbose=True)
- jitcode = cw.mainjitcode
+ jitcode = jitdriver_sd.mainjitcode
print jitcode.dump()
[jitcode2] = cw.assembler.descrs
print jitcode2.dump()
@@ -117,7 +123,7 @@
return x().id + y().id + dont_look(n)
rtyper = support.annotate(f, [35])
maingraph = rtyper.annotator.translator.graphs[0]
- cw = CodeWriter(FakeCPU(rtyper), maingraph)
+ cw = CodeWriter(FakeCPU(rtyper), [FakeJitDriverSD(maingraph)])
cw.find_all_graphs(FakePolicy())
cw.make_jitcodes(verbose=True)
#
@@ -144,10 +150,10 @@
def f(n):
return abs(n)
rtyper = support.annotate(f, [35])
- maingraph = rtyper.annotator.translator.graphs[0]
- cw = CodeWriter(FakeCPU(rtyper), maingraph)
+ jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+ cw = CodeWriter(FakeCPU(rtyper), [jitdriver_sd])
cw.find_all_graphs(FakePolicy())
cw.make_jitcodes(verbose=True)
#
- s = cw.mainjitcode.dump()
+ s = jitdriver_sd.mainjitcode.dump()
assert "inline_call_ir_i <JitCode '_ll_1_int_abs__Signed'>" in s
Modified: pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py (original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py Fri Jun 11 11:52:42 2010
@@ -68,11 +68,8 @@
return FakeDescr()
def calldescr_canraise(self, calldescr):
return calldescr is not self._descr_cannot_raise
- def found_jitdriver(self, jitdriver):
- assert isinstance(jitdriver, JitDriver)
- self.jitdriver = jitdriver
- def getjitdriver(self):
- return self.jitdriver
+ def get_vinfo(self, VTYPEPTR):
+ return None
class FakeCallControlWithVRefInfo:
class virtualref_info:
More information about the Pypy-commit mailing list