[Python-checkins] bpo-31778: Make ast.literal_eval() more strict. (#4035)

Serhiy Storchaka webhook-mailer at python.org
Thu Jan 4 04:15:47 EST 2018


https://github.com/python/cpython/commit/d8ac4d1d5ac256ebf3d8d38c226049abec82a2a0
commit: d8ac4d1d5ac256ebf3d8d38c226049abec82a2a0
branch: master
author: Serhiy Storchaka <storchaka at gmail.com>
committer: GitHub <noreply at github.com>
date: 2018-01-04T11:15:39+02:00
summary:

bpo-31778: Make ast.literal_eval() more strict. (#4035)

Addition and subtraction of arbitrary numbers are no longer allowed.

files:
A Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst
M Lib/ast.py
M Lib/test/test_ast.py
M Lib/test/test_inspect.py

diff --git a/Lib/ast.py b/Lib/ast.py
index 070c2bee7f9..2ecb03f38bc 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -35,8 +35,6 @@ def parse(source, filename='<unknown>', mode='exec'):
     return compile(source, filename, mode, PyCF_ONLY_AST)
 
 
-_NUM_TYPES = (int, float, complex)
-
 def literal_eval(node_or_string):
     """
     Safely evaluate an expression node or a string containing a Python
@@ -48,6 +46,21 @@ def literal_eval(node_or_string):
         node_or_string = parse(node_or_string, mode='eval')
     if isinstance(node_or_string, Expression):
         node_or_string = node_or_string.body
+    def _convert_num(node):
+        if isinstance(node, Constant):
+            if isinstance(node.value, (int, float, complex)):
+                return node.value
+        elif isinstance(node, Num):
+            return node.n
+        raise ValueError('malformed node or string: ' + repr(node))
+    def _convert_signed_num(node):
+        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
+            operand = _convert_num(node.operand)
+            if isinstance(node.op, UAdd):
+                return + operand
+            else:
+                return - operand
+        return _convert_num(node)
     def _convert(node):
         if isinstance(node, Constant):
             return node.value
@@ -62,26 +75,19 @@ def _convert(node):
         elif isinstance(node, Set):
             return set(map(_convert, node.elts))
         elif isinstance(node, Dict):
-            return dict((_convert(k), _convert(v)) for k, v
-                        in zip(node.keys, node.values))
+            return dict(zip(map(_convert, node.keys),
+                            map(_convert, node.values)))
         elif isinstance(node, NameConstant):
             return node.value
-        elif isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
-            operand = _convert(node.operand)
-            if isinstance(operand, _NUM_TYPES):
-                if isinstance(node.op, UAdd):
-                    return + operand
-                else:
-                    return - operand
         elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
-            left = _convert(node.left)
-            right = _convert(node.right)
-            if isinstance(left, _NUM_TYPES) and isinstance(right, _NUM_TYPES):
+            left = _convert_signed_num(node.left)
+            right = _convert_num(node.right)
+            if isinstance(left, (int, float)) and isinstance(right, complex):
                 if isinstance(node.op, Add):
                     return left + right
                 else:
                     return left - right
-        raise ValueError('malformed node or string: ' + repr(node))
+        return _convert_signed_num(node)
     return _convert(node_or_string)
 
 
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index aa53503e3b5..67f363ad31f 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -551,14 +551,37 @@ def test_literal_eval(self):
         self.assertEqual(ast.literal_eval('{1, 2, 3}'), {1, 2, 3})
         self.assertEqual(ast.literal_eval('b"hi"'), b"hi")
         self.assertRaises(ValueError, ast.literal_eval, 'foo()')
+        self.assertEqual(ast.literal_eval('6'), 6)
+        self.assertEqual(ast.literal_eval('+6'), 6)
         self.assertEqual(ast.literal_eval('-6'), -6)
-        self.assertEqual(ast.literal_eval('-6j+3'), 3-6j)
         self.assertEqual(ast.literal_eval('3.25'), 3.25)
-
-    def test_literal_eval_issue4907(self):
-        self.assertEqual(ast.literal_eval('2j'), 2j)
-        self.assertEqual(ast.literal_eval('10 + 2j'), 10 + 2j)
-        self.assertEqual(ast.literal_eval('1.5 - 2j'), 1.5 - 2j)
+        self.assertEqual(ast.literal_eval('+3.25'), 3.25)
+        self.assertEqual(ast.literal_eval('-3.25'), -3.25)
+        self.assertEqual(repr(ast.literal_eval('-0.0')), '-0.0')
+        self.assertRaises(ValueError, ast.literal_eval, '++6')
+        self.assertRaises(ValueError, ast.literal_eval, '+True')
+        self.assertRaises(ValueError, ast.literal_eval, '2+3')
+
+    def test_literal_eval_complex(self):
+        # Issue #4907
+        self.assertEqual(ast.literal_eval('6j'), 6j)
+        self.assertEqual(ast.literal_eval('-6j'), -6j)
+        self.assertEqual(ast.literal_eval('6.75j'), 6.75j)
+        self.assertEqual(ast.literal_eval('-6.75j'), -6.75j)
+        self.assertEqual(ast.literal_eval('3+6j'), 3+6j)
+        self.assertEqual(ast.literal_eval('-3+6j'), -3+6j)
+        self.assertEqual(ast.literal_eval('3-6j'), 3-6j)
+        self.assertEqual(ast.literal_eval('-3-6j'), -3-6j)
+        self.assertEqual(ast.literal_eval('3.25+6.75j'), 3.25+6.75j)
+        self.assertEqual(ast.literal_eval('-3.25+6.75j'), -3.25+6.75j)
+        self.assertEqual(ast.literal_eval('3.25-6.75j'), 3.25-6.75j)
+        self.assertEqual(ast.literal_eval('-3.25-6.75j'), -3.25-6.75j)
+        self.assertEqual(ast.literal_eval('(3+6j)'), 3+6j)
+        self.assertRaises(ValueError, ast.literal_eval, '-6j+3')
+        self.assertRaises(ValueError, ast.literal_eval, '-6j+3j')
+        self.assertRaises(ValueError, ast.literal_eval, '3+-6j')
+        self.assertRaises(ValueError, ast.literal_eval, '3+(0+6j)')
+        self.assertRaises(ValueError, ast.literal_eval, '-(3+6j)')
 
     def test_bad_integer(self):
         # issue13436: Bad error message with invalid numeric values
@@ -1077,11 +1100,11 @@ def test_literal_eval(self):
         ast.copy_location(new_left, binop.left)
         binop.left = new_left
 
-        new_right = ast.Constant(value=20)
+        new_right = ast.Constant(value=20j)
         ast.copy_location(new_right, binop.right)
         binop.right = new_right
 
-        self.assertEqual(ast.literal_eval(binop), 30)
+        self.assertEqual(ast.literal_eval(binop), 10+20j)
 
 
 def main():
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index e8eddbedf7e..cb51f8aff29 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -2074,7 +2074,7 @@ def p(name): return signature.parameters[name].default
         self.assertEqual(p('f'), False)
         self.assertEqual(p('local'), 3)
         self.assertEqual(p('sys'), sys.maxsize)
-        self.assertEqual(p('exp'), sys.maxsize - 1)
+        self.assertNotIn('exp', signature.parameters)
 
         test_callable(object)
 
diff --git a/Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst b/Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst
new file mode 100644
index 00000000000..452ad6e4bd2
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst
@@ -0,0 +1,2 @@
+ast.literal_eval() is now more strict. Addition and subtraction of
+arbitrary numbers are no longer allowed.



More information about the Python-checkins mailing list