[pypy-svn] pypy default: Patch by qbproger: write square roots with the x86 SSE2 instruction SQRTSD
arigo
commits-noreply at bitbucket.org
Sat May 7 12:07:42 CEST 2011
Author: Armin Rigo <arigo at tunes.org>
Branch:
Changeset: r43935:9439564ba9b3
Date: 2011-05-07 12:07 +0200
http://bitbucket.org/pypy/pypy/changeset/9439564ba9b3/
Log: Patch by qbproger: compute square roots with the x86 SSE2 instruction
SQRTSD. Also cleans up the error checking around math.sqrt() a bit.
diff --git a/pypy/jit/backend/x86/assembler.py b/pypy/jit/backend/x86/assembler.py
--- a/pypy/jit/backend/x86/assembler.py
+++ b/pypy/jit/backend/x86/assembler.py
@@ -832,6 +832,11 @@
effectinfo = op.getdescr().get_extra_info()
oopspecindex = effectinfo.oopspecindex
genop_llong_list[oopspecindex](self, op, arglocs, resloc)
+
+ def regalloc_perform_math(self, op, arglocs, resloc):
+ # Emit code for a CALL that the register allocator handled as a math
+ # oopspec: look up the genop_math_* handler registered in
+ # genop_math_list under the call descr's oopspecindex and run it.
+ effectinfo = op.getdescr().get_extra_info()
+ oopspecindex = effectinfo.oopspecindex
+ genop_math_list[oopspecindex](self, op, arglocs, resloc)
def regalloc_perform_with_guard(self, op, guard_op, faillocs,
arglocs, resloc, current_depths):
@@ -1119,6 +1124,9 @@
genop_guard_float_eq = _cmpop_guard_float("E", "E", "NE","NE")
genop_guard_float_gt = _cmpop_guard_float("A", "B", "BE","AE")
genop_guard_float_ge = _cmpop_guard_float("AE","BE", "B", "A")
+
+ def genop_math_sqrt(self, op, arglocs, resloc):
+ # Square root of a float via the SSE2 SQRTSD instruction; registered
+ # for EffectInfo.OS_MATH_SQRT by the genop_math_ prefix scan.
+ # NOTE(review): SQRTSD_x* suggests (dest, src) operand order, which
+ # would make this compute sqrt(resloc) into arglocs[0]. That is only
+ # correct because _consider_math_sqrt passes the same location for
+ # both -- confirm before reusing with distinct locations.
+ self.mc.SQRTSD(arglocs[0], resloc)
def genop_guard_float_ne(self, op, guard_op, guard_token, arglocs, result_loc):
guard_opnum = guard_op.getopnum()
@@ -2158,6 +2166,7 @@
genop_discard_list = [Assembler386.not_implemented_op_discard] * rop._LAST
genop_list = [Assembler386.not_implemented_op] * rop._LAST
genop_llong_list = {}
+genop_math_list = {}
genop_guard_list = [Assembler386.not_implemented_op_guard] * rop._LAST
for name, value in Assembler386.__dict__.iteritems():
@@ -2173,6 +2182,10 @@
opname = name[len('genop_llong_'):]
num = getattr(EffectInfo, 'OS_LLONG_' + opname.upper())
genop_llong_list[num] = value
+ elif name.startswith('genop_math_'):
+ opname = name[len('genop_math_'):]
+ num = getattr(EffectInfo, 'OS_MATH_' + opname.upper())
+ genop_math_list[num] = value
elif name.startswith('genop_'):
opname = name[len('genop_'):]
num = getattr(rop, opname.upper())
diff --git a/pypy/jit/backend/x86/regalloc.py b/pypy/jit/backend/x86/regalloc.py
--- a/pypy/jit/backend/x86/regalloc.py
+++ b/pypy/jit/backend/x86/regalloc.py
@@ -328,6 +328,11 @@
if not we_are_translated():
self.assembler.dump('%s <- %s(%s)' % (result_loc, op, arglocs))
self.assembler.regalloc_perform_llong(op, arglocs, result_loc)
+
+ def PerformMath(self, op, arglocs, result_loc):
+ # Counterpart of PerformLLong for math oopspec calls. When running
+ # untranslated, log the operation via assembler.dump(), then let the
+ # assembler emit the code through regalloc_perform_math.
+ if not we_are_translated():
+ self.assembler.dump('%s <- %s(%s)' % (result_loc, op, arglocs))
+ self.assembler.regalloc_perform_math(op, arglocs, result_loc)
def locs_for_fail(self, guard_op):
return [self.loc(v) for v in guard_op.getfailargs()]
@@ -661,15 +666,13 @@
consider_float_gt = _consider_float_cmp
consider_float_ge = _consider_float_cmp
- def consider_float_neg(self, op):
+ def _consider_float_unary_op(self, op):
+ # Shared register allocation for one-operand float ops: the result is
+ # forced into the register already holding argument 0 (via self.xrm),
+ # and Perform receives loc0 as both the argument and result location.
loc0 = self.xrm.force_result_in_reg(op.result, op.getarg(0))
self.Perform(op, [loc0], loc0)
self.xrm.possibly_free_var(op.getarg(0))
-
- def consider_float_abs(self, op):
- loc0 = self.xrm.force_result_in_reg(op.result, op.getarg(0))
- self.Perform(op, [loc0], loc0)
- self.xrm.possibly_free_var(op.getarg(0))
+
+ # FLOAT_NEG and FLOAT_ABS use the identical allocation pattern.
+ consider_float_neg = _consider_float_unary_op
+ consider_float_abs = _consider_float_unary_op
def consider_cast_float_to_int(self, op):
loc0 = self.xrm.make_sure_var_in_reg(op.getarg(0))
@@ -755,6 +758,11 @@
loc1 = self.rm.make_sure_var_in_reg(op.getarg(1))
self.PerformLLong(op, [loc1], loc0)
self.rm.possibly_free_vars_for_op(op)
+
+ def _consider_math_sqrt(self, op):
+ # op is a CALL whose effectinfo says OS_MATH_SQRT. Argument 0 is
+ # presumably the called function; argument 1 is the float operand.
+ # The result is forced into the operand's register and PerformMath
+ # gets the same location for input and output. consider_call returns
+ # right after this, so _consider_call (the real call) is never emitted.
+ loc0 = self.xrm.force_result_in_reg(op.result, op.getarg(1))
+ self.PerformMath(op, [loc0], loc0)
+ self.xrm.possibly_free_var(op.getarg(1))
def _call(self, op, arglocs, force_store=[], guard_not_forced_op=None):
save_all_regs = guard_not_forced_op is not None
@@ -791,12 +799,12 @@
guard_not_forced_op=guard_not_forced_op)
def consider_call(self, op):
- if IS_X86_32:
- # support for some of the llong operations,
- # which only exist on x86-32
- effectinfo = op.getdescr().get_extra_info()
- if effectinfo is not None:
- oopspecindex = effectinfo.oopspecindex
+ effectinfo = op.getdescr().get_extra_info()
+ if effectinfo is not None:
+ oopspecindex = effectinfo.oopspecindex
+ if IS_X86_32:
+ # support for some of the llong operations,
+ # which only exist on x86-32
if oopspecindex in (EffectInfo.OS_LLONG_ADD,
EffectInfo.OS_LLONG_SUB,
EffectInfo.OS_LLONG_AND,
@@ -815,7 +823,8 @@
if oopspecindex == EffectInfo.OS_LLONG_LT:
if self._maybe_consider_llong_lt(op):
return
- #
+ if oopspecindex == EffectInfo.OS_MATH_SQRT:
+ return self._consider_math_sqrt(op)
self._consider_call(op)
def consider_call_may_force(self, op, guard_op):
diff --git a/pypy/jit/backend/x86/regloc.py b/pypy/jit/backend/x86/regloc.py
--- a/pypy/jit/backend/x86/regloc.py
+++ b/pypy/jit/backend/x86/regloc.py
@@ -515,6 +515,8 @@
UCOMISD = _binaryop('UCOMISD')
CVTSI2SD = _binaryop('CVTSI2SD')
CVTTSD2SI = _binaryop('CVTTSD2SI')
+
+ SQRTSD = _binaryop('SQRTSD')
ANDPD = _binaryop('ANDPD')
XORPD = _binaryop('XORPD')
diff --git a/pypy/jit/backend/x86/rx86.py b/pypy/jit/backend/x86/rx86.py
--- a/pypy/jit/backend/x86/rx86.py
+++ b/pypy/jit/backend/x86/rx86.py
@@ -691,6 +691,8 @@
define_modrm_modes('MOVSD_x*', ['\xF2', rex_nw, '\x0F\x10', register(1,8)], regtype='XMM')
define_modrm_modes('MOVSD_*x', ['\xF2', rex_nw, '\x0F\x11', register(2,8)], regtype='XMM')
+define_modrm_modes('SQRTSD_x*', ['\xF2', rex_nw, '\x0F\x51', register(1,8)], regtype='XMM')
+
#define_modrm_modes('XCHG_r*', [rex_w, '\x87', register(1, 8)])
define_modrm_modes('ADDSD_x*', ['\xF2', rex_nw, '\x0F\x58', register(1, 8)], regtype='XMM')
diff --git a/pypy/jit/codewriter/effectinfo.py b/pypy/jit/codewriter/effectinfo.py
--- a/pypy/jit/codewriter/effectinfo.py
+++ b/pypy/jit/codewriter/effectinfo.py
@@ -72,6 +72,8 @@
OS_LLONG_UGE = 91
OS_LLONG_URSHIFT = 92
OS_LLONG_FROM_UINT = 93
+ #
+ OS_MATH_SQRT = 100
def __new__(cls, readonly_descrs_fields,
write_descrs_fields, write_descrs_arrays,
diff --git a/pypy/jit/codewriter/jtransform.py b/pypy/jit/codewriter/jtransform.py
--- a/pypy/jit/codewriter/jtransform.py
+++ b/pypy/jit/codewriter/jtransform.py
@@ -351,6 +351,8 @@
prepare = self._handle_jit_call
elif oopspec_name.startswith('libffi_'):
prepare = self._handle_libffi_call
+ elif oopspec_name.startswith('math.sqrt'):
+ prepare = self._handle_math_sqrt_call
else:
prepare = self.prepare_builtin_call
try:
@@ -1360,6 +1362,13 @@
assert vinfo is not None
self.vable_flags[op.args[0]] = op.args[2].value
return []
+
+ # ---------
+ # ll_math.sqrt_nonneg()
+
+ def _handle_math_sqrt_call(self, op, oopspec_name, args):
+ # Rewrite a call whose oopspec name starts with 'math.sqrt' into an
+ # oopspec call tagged OS_MATH_SQRT with EF_PURE, so a backend can
+ # recognise it. oopspec_name is not used beyond the dispatch.
+ return self._handle_oopspec_call(op, args, EffectInfo.OS_MATH_SQRT,
+ EffectInfo.EF_PURE)
# ____________________________________________________________
diff --git a/pypy/jit/codewriter/support.py b/pypy/jit/codewriter/support.py
--- a/pypy/jit/codewriter/support.py
+++ b/pypy/jit/codewriter/support.py
@@ -4,6 +4,7 @@
from pypy.rpython import rlist
from pypy.rpython.lltypesystem import rstr as ll_rstr, rdict as ll_rdict
from pypy.rpython.lltypesystem import rlist as lltypesystem_rlist
+from pypy.rpython.lltypesystem.module import ll_math
from pypy.rpython.lltypesystem.lloperation import llop
from pypy.rpython.ootypesystem import rdict as oo_rdict
from pypy.rpython.llinterp import LLInterpreter
@@ -221,6 +222,11 @@
return -x
else:
return x
+
+# math support
+# ------------
+
+# Helper for the ('ll_math.ll_math_sqrt', [Float], Float) entry in the
+# builtin list further down. Presumably the `_ll_1_` prefix encodes the
+# one-argument arity -- TODO confirm against the other `_ll_*` helpers.
+_ll_1_ll_math_ll_math_sqrt = ll_math.ll_math_sqrt
# long long support
@@ -388,6 +394,7 @@
('int_mod_zer', [lltype.Signed, lltype.Signed], lltype.Signed),
('int_lshift_ovf', [lltype.Signed, lltype.Signed], lltype.Signed),
('int_abs', [lltype.Signed], lltype.Signed),
+ ('ll_math.ll_math_sqrt', [lltype.Float], lltype.Float),
]
diff --git a/pypy/jit/codewriter/test/test_jtransform.py b/pypy/jit/codewriter/test/test_jtransform.py
--- a/pypy/jit/codewriter/test/test_jtransform.py
+++ b/pypy/jit/codewriter/test/test_jtransform.py
@@ -5,6 +5,7 @@
from pypy.jit.codewriter.jtransform import Transformer
from pypy.jit.metainterp.history import getkind
from pypy.rpython.lltypesystem import lltype, llmemory, rclass, rstr, rlist
+from pypy.rpython.lltypesystem.module import ll_math
from pypy.translator.unsimplify import varoftype
from pypy.jit.codewriter import heaptracker, effectinfo
from pypy.jit.codewriter.flatten import ListOfKind
@@ -98,7 +99,9 @@
PUNICODE = lltype.Ptr(rstr.UNICODE)
INT = lltype.Signed
UNICHAR = lltype.UniChar
+ FLOAT = lltype.Float
argtypes = {
+ EI.OS_MATH_SQRT: ([FLOAT], FLOAT),
EI.OS_STR2UNICODE:([PSTR], PUNICODE),
EI.OS_STR_CONCAT: ([PSTR, PSTR], PSTR),
EI.OS_STR_SLICE: ([PSTR, INT, INT], PSTR),
@@ -947,3 +950,22 @@
assert op1.args[1] == 'calldescr-%d' % effectinfo.EffectInfo.OS_ARRAYCOPY
assert op1.args[2] == ListOfKind('int', [v3, v4, v5])
assert op1.args[3] == ListOfKind('ref', [v1, v2])
+
+def test_math_sqrt():
+ # test that the oopspec is present and correctly transformed:
+ # a direct_call to ll_math.sqrt_nonneg must become a float-returning
+ # residual call whose calldescr carries OS_MATH_SQRT, with empty
+ # int/ref argument lists and the operand in the float list.
+ FLOAT = lltype.Float
+ FUNC = lltype.FuncType([FLOAT], FLOAT)
+ func = lltype.functionptr(FUNC, 'll_math',
+ _callable=ll_math.sqrt_nonneg)
+ v1 = varoftype(FLOAT)
+ v2 = varoftype(FLOAT)
+ op = SpaceOperation('direct_call', [const(func), v1], v2)
+ tr = Transformer(FakeCPU(), FakeBuiltinCallControl())
+ op1 = tr.rewrite_operation(op)
+ assert op1.opname == 'residual_call_irf_f'
+ assert op1.args[0].value == func
+ assert op1.args[1] == 'calldescr-%d' % effectinfo.EffectInfo.OS_MATH_SQRT
+ assert op1.args[2] == ListOfKind("int", [])
+ assert op1.args[3] == ListOfKind("ref", [])
+ assert op1.args[4] == ListOfKind('float', [v1])
+ assert op1.result == v2
diff --git a/pypy/jit/metainterp/test/support.py b/pypy/jit/metainterp/test/support.py
--- a/pypy/jit/metainterp/test/support.py
+++ b/pypy/jit/metainterp/test/support.py
@@ -9,6 +9,7 @@
from pypy.jit.metainterp.warmstate import set_future_value
from pypy.jit.codewriter.policy import JitPolicy
from pypy.jit.codewriter import longlong
+from pypy.rlib.rfloat import isinf, isnan
def _get_jitcodes(testself, CPUClass, func, values, type_system,
supports_longlong=False, **kwds):
@@ -181,10 +182,10 @@
result1 = _run_with_blackhole(self, args)
# try to run it with pyjitpl.py
result2 = _run_with_pyjitpl(self, args)
- assert result1 == result2
+ assert result1 == result2 or isnan(result1) and isnan(result2)
# try to run it by running the code compiled just before
result3 = _run_with_machine_code(self, args)
- assert result1 == result3 or result3 == NotImplemented
+ assert result1 == result3 or result3 == NotImplemented or isnan(result1) and isnan(result3)
#
if (longlong.supports_longlong and
isinstance(result1, longlong.r_float_storage)):
diff --git a/pypy/rpython/extfuncregistry.py b/pypy/rpython/extfuncregistry.py
--- a/pypy/rpython/extfuncregistry.py
+++ b/pypy/rpython/extfuncregistry.py
@@ -45,6 +45,9 @@
register_external(math.floor, [float], float,
export_name="ll_math.ll_math_floor", sandboxsafe=True,
llimpl=ll_math.ll_math_floor)
+# math.sqrt now uses its own hand-written llimpl (ll_math_sqrt); 'sqrt'
+# is removed from unary_math_functions in ll_math.py by the same change.
+register_external(math.sqrt, [float], float,
+ export_name="ll_math.ll_math_sqrt", sandboxsafe=True,
+ llimpl=ll_math.ll_math_sqrt)
complex_math_functions = [
('frexp', [float], (float, int)),
diff --git a/pypy/rpython/lltypesystem/module/ll_math.py b/pypy/rpython/lltypesystem/module/ll_math.py
--- a/pypy/rpython/lltypesystem/module/ll_math.py
+++ b/pypy/rpython/lltypesystem/module/ll_math.py
@@ -9,7 +9,7 @@
from pypy.rlib import jit, rposix
from pypy.translator.tool.cbuild import ExternalCompilationInfo
from pypy.translator.platform import platform
-from pypy.rlib.rfloat import isinf, isnan, INFINITY, NAN
+from pypy.rlib.rfloat import isfinite, isinf, isnan, INFINITY, NAN
if sys.platform == "win32":
if platform.name == "msvc":
@@ -69,6 +69,13 @@
[rffi.DOUBLE, rffi.DOUBLE], rffi.DOUBLE)
math_floor = llexternal('floor', [rffi.DOUBLE], rffi.DOUBLE, pure_function=True)
+math_sqrt = llexternal('sqrt', [rffi.DOUBLE], rffi.DOUBLE)
+
+ at jit.purefunction
+def sqrt_nonneg(x):
+ return math_sqrt(x)
+sqrt_nonneg.oopspec = "math.sqrt_nonneg(x)"
+
# ____________________________________________________________
#
# Error handling functions
@@ -319,6 +326,15 @@
_likely_raise(errno, r)
return r
+def ll_math_sqrt(x):
+ # RPython implementation of math.sqrt() (registered as its llimpl in
+ # extfuncregistry). Any x < 0.0, including -inf, raises ValueError.
+ # +inf and nan fail isfinite() and are returned unchanged. Every other
+ # value goes through sqrt_nonneg, which the JIT can inline; note that
+ # -0.0 is not < 0.0, so it also takes that path.
+ if x < 0.0:
+ raise ValueError, "math domain error"
+
+ if isfinite(x):
+ return sqrt_nonneg(x)
+
+ return x # +inf or nan
+
# ____________________________________________________________
#
# Default implementations
@@ -357,7 +373,7 @@
unary_math_functions = [
'acos', 'asin', 'atan',
'ceil', 'cos', 'cosh', 'exp', 'fabs',
- 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'log', 'log10',
+ 'sin', 'sinh', 'tan', 'tanh', 'log', 'log10',
'acosh', 'asinh', 'atanh', 'log1p', 'expm1',
]
unary_math_functions_can_overflow = [
More information about the Pypy-commit mailing list