[Jython-checkins] jython: Improvements to marshal and its error messages, and fix for bjo #2744.

jeff.allen jython-checkins at python.org
Sat Mar 23 04:50:15 EDT 2019


https://hg.python.org/jython/rev/1e7671d4af8d
changeset:   8229:1e7671d4af8d
user:        Jeff Allen <ja.py at farowl.co.uk>
date:        Sat Mar 23 08:12:27 2019 +0000
summary:
  Improvements to marshal and its error messages, and fix for bjo #2744.

Updating test_marshal, in an attempt to find a failing test for #2077, revealed
other divergences from CPython that are unrelated to #2077. This change addresses
those, notably by adding support for objects with the buffer protocol (#2744).
Messages in ValueError exceptions are improved to match those from CPython.

files:
  Lib/test/test_marshal.py             |    3 -
  NEWS                                 |    1 +
  src/org/python/modules/_marshal.java |  249 ++++++++------
  3 files changed, 140 insertions(+), 113 deletions(-)


diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py
--- a/Lib/test/test_marshal.py
+++ b/Lib/test/test_marshal.py
@@ -136,7 +136,6 @@
             self.assertEqual(type(s), type(new))
         os.unlink(test_support.TESTFN)
 
-    @unittest.skipIf(test_support.is_jython, "FIXME: bjo #2744 buffer not supported")
     def test_buffer(self):
         for s in ["", "André Previn", "abc", " "*10000]:
             with test_support.check_py3k_warnings(("buffer.. not supported",
@@ -252,7 +251,6 @@
         last.append([0])
         self.assertRaises(ValueError, marshal.dumps, head)
 
-    @unittest.skipIf(test_support.is_jython, "FIXME: bjo #2077 ValueError not raised")
     def test_exact_type_match(self):
         # Former bug:
         #   >>> class Int(int): pass
@@ -271,7 +269,6 @@
         testString = 'abc' * size
         marshal.dumps(testString)
 
-    @unittest.skipIf(test_support.is_jython, "FIXME: bjo #2077 ValueError not raised")
     def test_invalid_longs(self):
         # Issue #7019: marshal.loads shouldn't produce unnormalized PyLongs
         invalid_string = 'l\x02\x00\x00\x00\x00\x00\x00\x00'
diff --git a/NEWS b/NEWS
--- a/NEWS
+++ b/NEWS
@@ -5,6 +5,7 @@
 
 Development tip
   Bugs fixed
+    - [ 2744 ] Support buffer type in marshal.dump(s)
     - [ 2077 ] marshal doesn't raise error when fed unmarshalable object
     - [ 2445 ] Eclipse's DelegatingFeatureMap has MRO conflict
     - [ 2732 ] Regression in large module support for pip
diff --git a/src/org/python/modules/_marshal.java b/src/org/python/modules/_marshal.java
--- a/src/org/python/modules/_marshal.java
+++ b/src/org/python/modules/_marshal.java
@@ -1,25 +1,31 @@
 package org.python.modules;
 
-import java.math.BigInteger;
 import org.python.core.BaseSet;
+import org.python.core.BufferProtocol;
 import org.python.core.ClassDictInit;
-import org.python.core.PyObject;
-import org.python.core.PyString;
 import org.python.core.Py;
+import org.python.core.PyBUF;
+import org.python.core.PyBuffer;
 import org.python.core.PyBytecode;
 import org.python.core.PyComplex;
 import org.python.core.PyDictionary;
+import org.python.core.PyException;
 import org.python.core.PyFloat;
 import org.python.core.PyFrozenSet;
 import org.python.core.PyInteger;
 import org.python.core.PyList;
 import org.python.core.PyLong;
+import org.python.core.PyObject;
 import org.python.core.PySet;
+import org.python.core.PyString;
 import org.python.core.PyTuple;
+import org.python.core.PyType;
 import org.python.core.PyUnicode;
 import org.python.core.Traverseproc;
 import org.python.core.Visitproc;
 
+import java.math.BigInteger;
+
 public class _marshal implements ClassDictInit {
 
     public static void classDictInit(PyObject dict) {
@@ -112,13 +118,12 @@
 
         // writes output in 15 bit "digits"
         private void write_long(BigInteger x) {
-            int sign = x.signum();
-            if (sign < 0) {
+            boolean negative = x.signum() < 0;
+            if (negative) {
                 x = x.negate();
             }
-            int num_bits = x.bitLength();
-            int num_digits = num_bits / 15 + (num_bits % 15 == 0 ? 0 : 1);
-            write_int(sign < 0 ? -num_digits : num_digits);
+            int num_digits = (x.bitLength() + 14) / 15;
+            write_int(negative ? -num_digits : num_digits);
             BigInteger mask = BigInteger.valueOf(0x7FFF);
             for (int i = 0; i < num_digits; i++) {
                 write_short(x.and(mask).shortValue());
@@ -149,96 +154,108 @@
                 write_byte(TYPE_FALSE);
             } else if (v == Py.True) {
                 write_byte(TYPE_TRUE);
-            } else if (v instanceof PyInteger) {
-                write_byte(TYPE_INT);
-                write_int(((PyInteger) v).asInt());
-            } else if (v instanceof PyLong) {
-                write_byte(TYPE_LONG);
-                write_long(((PyLong) v).getValue());
-            } else if (v instanceof PyFloat) {
-                if (version == CURRENT_VERSION) {
-                    write_byte(TYPE_BINARY_FLOAT);
-                    write_binary_float((PyFloat) v);
-                } else {
-                    write_byte(TYPE_FLOAT);
-                    write_float((PyFloat) v);
-                }
-            } else if (v instanceof PyComplex) {
-                PyComplex x = (PyComplex) v;
-                if (version == CURRENT_VERSION) {
-                    write_byte(TYPE_BINARY_COMPLEX);
-                    write_binary_float(x.getReal());
-                    write_binary_float(x.getImag());
+            } else {
+                PyType vt = v.getType();
+                if (vt == PyInteger.TYPE) {
+                    write_byte(TYPE_INT);
+                    write_int(((PyInteger) v).asInt());
+                } else if (vt == PyLong.TYPE) {
+                    write_byte(TYPE_LONG);
+                    write_long(((PyLong) v).getValue());
+                } else if (vt == PyFloat.TYPE) {
+                    if (version == CURRENT_VERSION) {
+                        write_byte(TYPE_BINARY_FLOAT);
+                        write_binary_float((PyFloat) v);
+                    } else {
+                        write_byte(TYPE_FLOAT);
+                        write_float((PyFloat) v);
+                    }
+                } else if (vt == PyComplex.TYPE) {
+                    PyComplex x = (PyComplex) v;
+                    if (version == CURRENT_VERSION) {
+                        write_byte(TYPE_BINARY_COMPLEX);
+                        write_binary_float(x.getReal());
+                        write_binary_float(x.getImag());
+                    } else {
+                        write_byte(TYPE_COMPLEX);
+                        write_float(x.getReal());
+                        write_float(x.getImag());
+                    }
+                } else if (vt == PyUnicode.TYPE) {
+                    write_byte(TYPE_UNICODE);
+                    String buffer = ((PyUnicode) v).encode("utf-8").toString();
+                    write_int(buffer.length());
+                    write_string(buffer);
+                } else if (vt == PyString.TYPE) {
+                    // ignore interning
+                    write_byte(TYPE_STRING);
+                    write_int(v.__len__());
+                    write_string(v.toString());
+                } else if (vt == PyTuple.TYPE) {
+                    write_byte(TYPE_TUPLE);
+                    PyTuple t = (PyTuple) v;
+                    int n = t.__len__();
+                    write_int(n);
+                    for (int i = 0; i < n; i++) {
+                        write_object(t.__getitem__(i), depth + 1);
+                    }
+                } else if (vt == PyList.TYPE) {
+                    write_byte(TYPE_LIST);
+                    PyList list = (PyList) v;
+                    int n = list.__len__();
+                    write_int(n);
+                    for (int i = 0; i < n; i++) {
+                        write_object(list.__getitem__(i), depth + 1);
+                    }
+                } else if (vt == PyDictionary.TYPE) {
+                    write_byte(TYPE_DICT);
+                    PyDictionary dict = (PyDictionary) v;
+                    for (PyObject item : dict.iteritems().asIterable()) {
+                        PyTuple pair = (PyTuple) item;
+                        write_object(pair.__getitem__(0), depth + 1);
+                        write_object(pair.__getitem__(1), depth + 1);
+                    }
+                    write_object(null, depth + 1);
+                } else if (vt == PySet.TYPE || vt == PyFrozenSet.TYPE) {
+                    if (vt == PySet.TYPE) {
+                        write_byte(TYPE_SET);
+                    } else {
+                        write_byte(TYPE_FROZENSET);
+                    }
+                    int n = v.__len__();
+                    write_int(n);
+                    BaseSet set = (BaseSet) v;
+                    for (PyObject item : set.asIterable()) {
+                        write_object(item, depth + 1);
+                    }
+                } else if (vt == PyBytecode.TYPE) {
+                    PyBytecode code = (PyBytecode) v;
+                    write_byte(TYPE_CODE);
+                    write_int(code.co_argcount);
+                    write_int(code.co_nlocals);
+                    write_int(code.co_stacksize);
+                    write_int(code.co_flags.toBits());
+                    write_object(Py.newString(new String(code.co_code)), depth + 1);
+                    write_object(new PyTuple(code.co_consts), depth + 1);
+                    write_strings(code.co_names, depth + 1);
+                    write_strings(code.co_varnames, depth + 1);
+                    write_strings(code.co_freevars, depth + 1);
+                    write_strings(code.co_cellvars, depth + 1);
+                    write_object(Py.newString(code.co_name), depth + 1);
+                    write_int(code.co_firstlineno);
+                    write_object(Py.newString(new String(code.co_lnotab)), depth + 1);
                 } else {
-                    write_byte(TYPE_COMPLEX);
-                    write_float(x.getReal());
-                    write_float(x.getImag());
-                }
-            } else if (v instanceof PyUnicode) {
-                write_byte(TYPE_UNICODE);
-                String buffer = ((PyUnicode) v).encode("utf-8").toString();
-                write_int(buffer.length());
-                write_string(buffer);
-            } else if (v instanceof PyString) {
-                // ignore interning
-                write_byte(TYPE_STRING);
-                write_int(v.__len__());
-                write_string(v.toString());
-            } else if (v instanceof PyTuple) {
-                write_byte(TYPE_TUPLE);
-                PyTuple t = (PyTuple) v;
-                int n = t.__len__();
-                write_int(n);
-                for (int i = 0; i < n; i++) {
-                    write_object(t.__getitem__(i), depth + 1);
+                    // Try to get a simple byte-oriented buffer
+                    try (PyBuffer buf = ((BufferProtocol) v).getBuffer(PyBUF.SIMPLE)) {
+                        // ... and treat those bytes as a String
+                        write_byte(TYPE_STRING);
+                        write_int(v.__len__());
+                        write_string(buf.toString());
+                    } catch (ClassCastException | PyException e) {
+                        // Does not implement BufferProtocol (in simple byte form).
+                        throw Py.ValueError("unmarshallable object");
+                    }
                 }
-            } else if (v instanceof PyList) {
-                write_byte(TYPE_LIST);
-                PyList list = (PyList) v;
-                int n = list.__len__();
-                write_int(n);
-                for (int i = 0; i < n; i++) {
-                    write_object(list.__getitem__(i), depth + 1);
-                }
-            } else if (v instanceof PyDictionary) {
-                write_byte(TYPE_DICT);
-                PyDictionary dict = (PyDictionary) v;
-                for (PyObject item : dict.iteritems().asIterable()) {
-                    PyTuple pair = (PyTuple) item;
-                    write_object(pair.__getitem__(0), depth + 1);
-                    write_object(pair.__getitem__(1), depth + 1);
-                }
-                write_object(null, depth + 1);
-            } else if (v instanceof BaseSet) {
-                if (v instanceof PySet) {
-                    write_byte(TYPE_SET);
-                } else {
-                    write_byte(TYPE_FROZENSET);
-                }
-                int n = v.__len__();
-                write_int(n);
-                BaseSet set = (BaseSet) v;
-                for (PyObject item : set.asIterable()) {
-                    write_object(item, depth + 1);
-                }
-            } else if (v instanceof PyBytecode) {
-                PyBytecode code = (PyBytecode) v;
-                write_byte(TYPE_CODE);
-                write_int(code.co_argcount);
-                write_int(code.co_nlocals);
-                write_int(code.co_stacksize);
-                write_int(code.co_flags.toBits());
-                write_object(Py.newString(new String(code.co_code)), depth + 1);
-                write_object(new PyTuple(code.co_consts), depth + 1);
-                write_strings(code.co_names, depth + 1);
-                write_strings(code.co_varnames, depth + 1);
-                write_strings(code.co_freevars, depth + 1);
-                write_strings(code.co_cellvars, depth + 1);
-                write_object(Py.newString(code.co_name), depth + 1);
-                write_int(code.co_firstlineno);
-                write_object(Py.newString(new String(code.co_lnotab)), depth + 1);
-            } else {
-                throw Py.ValueError("unmarshallable object");
             }
             depth--;
         }
@@ -327,21 +344,25 @@
         }
 
         private BigInteger read_long() {
-            int size = read_int();
-            int sign = 1;
-            if (size < 0) {
-                sign = -1;
+            BigInteger result = BigInteger.ZERO;
+            boolean negative = false;
+            int digit = 0, size = read_int();
+            if (size == 0) {
+                return result;
+            } else if (size < 0) {
+                negative = true;
                 size = -size;
             }
-            BigInteger result = BigInteger.ZERO;
             for (int i = 0; i < size; i++) {
-                String digits = String.valueOf(read_short());
-                result = result.or(new BigInteger(digits).shiftLeft(i * 15));
+                if ((digit = read_short()) < 0) {
+                    throw badMarshalData("digit out of range in long");
+                }
+                result = result.or(BigInteger.valueOf(digit).shiftLeft(i * 15));
             }
-            if (sign < 0) {
-                result = result.negate();
+            if (digit == 0) {
+                throw badMarshalData("unnormalized long data");
             }
-            return result;
+            return negative ? result.negate() : result;
         }
 
         private double read_float() {
@@ -356,7 +377,7 @@
         private PyObject read_object_notnull(int depth) {
             PyObject v = read_object(depth);
             if (v == null) {
-                throw Py.ValueError("bad marshal data");
+                throw badMarshalData(null);
             }
             return v;
         }
@@ -451,7 +472,7 @@
                 case TYPE_TUPLE: {
                     int n = read_int();
                     if (n < 0) {
-                        throw Py.ValueError("bad marshal data");
+                        throw badMarshalData(null);
                     }
                     PyObject items[] = new PyObject[n];
                     for (int i = 0; i < n; i++) {
@@ -463,7 +484,7 @@
                 case TYPE_LIST: {
                     int n = read_int();
                     if (n < 0) {
-                        throw Py.ValueError("bad marshal data");
+                        throw badMarshalData(null);
                     }
                     PyObject items[] = new PyObject[n];
                     for (int i = 0; i < n; i++) {
@@ -528,10 +549,18 @@
                 }
 
                 default:
-                    throw Py.ValueError("bad marshal data");
+                    throw badMarshalData("unknown type code");
             }
         }
 
+        /** Helper returning "bad marshal data" or "bad marshal data (<reason>)". */
+        private static PyException badMarshalData(String reason) {
+            StringBuilder msg = (new StringBuilder(60)).append("bad marshal data");
+            if (reason != null) {
+                msg.append(" (").append(reason).append(')');
+            }
+            return Py.ValueError(msg.toString());
+        }
 
         /* Traverseproc implementation */
         @Override

-- 
Repository URL: https://hg.python.org/jython


More information about the Jython-checkins mailing list