From jython-checkins at python.org Mon Dec 1 00:25:21 2014 From: jython-checkins at python.org (jeff.allen) Date: Sun, 30 Nov 2014 23:25:21 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Prevent_non-byte_values_app?= =?utf-8?q?earing_in_str=28=29=2C_fixes_=232037=2E?= Message-ID: <20141130232521.69799.28299@psf.io> https://hg.python.org/jython/rev/f0c63b42e552 changeset: 7424:f0c63b42e552 parent: 7405:7c731ca90075 user: Jeff Allen date: Sun Nov 16 22:50:43 2014 +0000 summary: Prevent non-byte values appearing in str(), fixes #2037. The constructor is augmented with a test to raise IllegalArgumentException if supplied a Java String (etc.) with any code point over 255. Also, the __str__ and __repr__ of proxied Java objects are modified to return unicode objects, with the result that str(o) raises UnicodeEncodeError for Java objects where o.toString() is not pure ascii. files: Lib/test/test_str_jy.py | 11 +- src/org/python/core/Py.java | 9 +- src/org/python/core/PyJavaType.java | 6 +- src/org/python/core/PyString.java | 70 ++++++++- src/org/python/core/PyUnicode.java | 3 +- tests/java/org/python/core/BaseBytesTest.java | 24 +-- 6 files changed, 90 insertions(+), 33 deletions(-) diff --git a/Lib/test/test_str_jy.py b/Lib/test/test_str_jy.py --- a/Lib/test/test_str_jy.py +++ b/Lib/test/test_str_jy.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from test import test_support +import java.lang import unittest class WrappedStrCmpTest(unittest.TestCase): @@ -23,12 +24,18 @@ ABC = Wrapper('ABC') self.assertEquals(1, d[ABC]) -class IntToStrTest(unittest.TestCase): +class StrConstructorTest(unittest.TestCase): def test_int_to_string_format(self): # 0.001 comes out as 0.0010 self.assertEquals(str(0.001), "0.001") + def test_unicode_resistance(self): + # Issue 2037: prevent byte/str elements > 255 + self.assertRaises(UnicodeEncodeError, str, java.lang.String(u"caf\xe9 noir")) + self.assertRaises(UnicodeEncodeError, str, java.lang.String(u"abc\u0111efgh")) + + class 
StringSlicingTest(unittest.TestCase): def test_out_of_bounds(self): @@ -165,7 +172,7 @@ def test_main(): test_support.run_unittest( WrappedStrCmpTest, - IntToStrTest, + StrConstructorTest, StringSlicingTest, FormatTest, DisplayTest, diff --git a/src/org/python/core/Py.java b/src/org/python/core/Py.java --- a/src/org/python/core/Py.java +++ b/src/org/python/core/Py.java @@ -1652,7 +1652,7 @@ static { for (char j = 0; j < 256; j++) { - letters[j] = new PyString(new Character(j).toString()); + letters[j] = new PyString(j); } } @@ -1667,11 +1667,8 @@ static final PyString makeCharacter(int codepoint, boolean toUnicode) { if (toUnicode) { return new PyUnicode(codepoint); - } else if (codepoint > 65536) { - throw new IllegalArgumentException(String.format("Codepoint > 65536 (%d) requires " - + "toUnicode argument", codepoint)); - } else if (codepoint > 256) { - return new PyString((char)codepoint); + } else if (codepoint > 255) { + throw new IllegalArgumentException("Cannot create PyString with non-byte value"); } return letters[codepoint]; } diff --git a/src/org/python/core/PyJavaType.java b/src/org/python/core/PyJavaType.java --- a/src/org/python/core/PyJavaType.java +++ b/src/org/python/core/PyJavaType.java @@ -611,8 +611,12 @@ addMethod(new PyBuiltinMethodNarrow("__repr__") { @Override public PyObject __call__() { + /* + * java.lang.Object.toString returns Unicode: preserve as a PyUnicode, then let + * the repr() built-in decide how to handle it. (Also applies to __str__.) + */ String toString = self.getJavaProxy().toString(); - return toString == null ? Py.EmptyString : Py.newString(toString); + return toString == null ? 
Py.EmptyUnicode : Py.newUnicode(toString); } }); addMethod(new PyBuiltinMethodNarrow("__unicode__") { diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -42,10 +42,19 @@ this(TYPE, ""); } + /** + * Fundamental constructor for PyString objects when the client provides a Java + * String, necessitating that we range check the characters. + * + * @param subType the actual type being constructed + * @param string a Java String to be wrapped + */ public PyString(PyType subType, String string) { super(subType); if (string == null) { throw new IllegalArgumentException("Cannot create PyString from null!"); + } else if (!isBytes(string)) { + throw new IllegalArgumentException("Cannot create PyString with non-byte value"); } this.string = string; } @@ -63,6 +72,40 @@ } /** + * Determine whether a string consists entirely of characters in the range 0 to 255. Only such + * characters are allowed in the PyString (str) type, when it is not a + * {@link PyUnicode}. + * + * @return true if and only if every character has a code less than 256 + */ + private static boolean isBytes(String s) { + int k = s.length(); + if (k == 0) { + return true; + } else { + // Bitwise-or the character codes together in order to test once. + char c = 0; + // Blocks of 8 to reduce loop tests + while (k > 8) { + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + } + // Now the rest + while (k > 0) { + c |= s.charAt(--k); + } + // We require there to be no bits set from 0x100 upwards + return c < 0x100; + } + } + + /** * Creates a PyString from an already interned String. Just means it won't be reinterned if used * in a place that requires interned Strings. 
*/ @@ -88,16 +131,25 @@ String[] keywords) { ArgParser ap = new ArgParser("str", args, keywords, new String[] {"object"}, 0); PyObject S = ap.getPyObject(0, null); + // Get the textual representation of the object into str/bytes form + String str; + if (S == null) { + str = ""; + } else { + // Let the object tell us its representation: this may be str or unicode. + S = S.__str__(); + if (S instanceof PyUnicode) { + // Encoding will raise UnicodeEncodeError if not 7-bit clean. + str = codecs.encode((PyUnicode)S, null, null); + } else { + // Must be str/bytes, and should be 8-bit clean already. + str = S.toString(); + } + } if (new_.for_type == subtype) { - if (S == null) { - return new PyString(""); - } - return new PyString(S.__str__().toString()); + return new PyString(str); } else { - if (S == null) { - return new PyStringDerived(subtype, ""); - } - return new PyStringDerived(subtype, S.__str__().toString()); + return new PyStringDerived(subtype, str); } } @@ -4606,7 +4658,7 @@ default: throw Py.ValueError("unsupported format character '" - + codecs.encode(Py.newString(spec.type), null, "replace") + "' (0x" + + codecs.encode(Py.newUnicode(spec.type), null, "replace") + "' (0x" + Integer.toHexString(spec.type) + ") at index " + (index - 1)); } diff --git a/src/org/python/core/PyUnicode.java b/src/org/python/core/PyUnicode.java --- a/src/org/python/core/PyUnicode.java +++ b/src/org/python/core/PyUnicode.java @@ -114,7 +114,8 @@ * @param isBasic true if it is known that only BMP characters are present. */ private PyUnicode(PyType subtype, String string, boolean isBasic) { - super(subtype, string); + super(subtype, ""); + this.string = string; translator = isBasic ? 
BASIC : this.chooseIndexTranslator(); } diff --git a/tests/java/org/python/core/BaseBytesTest.java b/tests/java/org/python/core/BaseBytesTest.java --- a/tests/java/org/python/core/BaseBytesTest.java +++ b/tests/java/org/python/core/BaseBytesTest.java @@ -304,22 +304,18 @@ // Need interpreter for exceptions to be formed properly interp = new PythonInterpreter(); // A scary set of objects - final PyObject[] brantub = {Py.None, - new PyInteger(-1), - new PyLong(0x80000000L), - new PyString("\u00A0\u0100\u00A2\u00A3\u00A4"), - new PyString("\u00A0\u00A0\u1000\u00A3\u00A4"), - new PyXRange(3, -2, -1), - new PyXRange(250, 257)}; + final PyObject[] brantub = {Py.None, new PyInteger(-1), // + new PyLong(0x80000000L), // + new PyXRange(3, -2, -1), // + new PyXRange(250, 257) // + }; // The PyException types we should obtain final PyObject[] boobyPrize = {Py.TypeError, // None - Py.ValueError, // -1 - Py.OverflowError, // 0x80000000L - Py.ValueError, // \u0100 byte - Py.ValueError, // \u1000 byte - Py.ValueError, // -1 in iterable - Py.ValueError // 256 in iterable - }; + Py.ValueError, // -1 + Py.OverflowError, // 0x80000000L + Py.ValueError, // -1 in iterable + Py.ValueError // 256 in iterable + }; // Work down the lists for (int dip = 0; dip < brantub.length; dip++) { PyObject aRef = boobyPrize[dip]; -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Mon Dec 1 00:25:22 2014 From: jython-checkins at python.org (jeff.allen) Date: Sun, 30 Nov 2014 23:25:22 +0000 Subject: [Jython-checkins] =?utf-8?q?jython_=28merge_default_-=3E_default?= =?utf-8?q?=29=3A_Merge_str_bytes_check_to_trunk?= Message-ID: <20141130232522.55123.81988@psf.io> https://hg.python.org/jython/rev/849ec9c291db changeset: 7425:849ec9c291db parent: 7423:6aa434d5dc01 parent: 7424:f0c63b42e552 user: Jeff Allen date: Sun Nov 30 23:25:03 2014 +0000 summary: Merge str bytes check to trunk files: Lib/test/test_str_jy.py | 11 +- src/org/python/core/Py.java | 9 +- 
src/org/python/core/PyJavaType.java | 6 +- src/org/python/core/PyString.java | 70 ++++++++- src/org/python/core/PyUnicode.java | 3 +- tests/java/org/python/core/BaseBytesTest.java | 24 +-- 6 files changed, 90 insertions(+), 33 deletions(-) diff --git a/Lib/test/test_str_jy.py b/Lib/test/test_str_jy.py --- a/Lib/test/test_str_jy.py +++ b/Lib/test/test_str_jy.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from test import test_support +import java.lang import unittest class WrappedStrCmpTest(unittest.TestCase): @@ -23,12 +24,18 @@ ABC = Wrapper('ABC') self.assertEquals(1, d[ABC]) -class IntToStrTest(unittest.TestCase): +class StrConstructorTest(unittest.TestCase): def test_int_to_string_format(self): # 0.001 comes out as 0.0010 self.assertEquals(str(0.001), "0.001") + def test_unicode_resistance(self): + # Issue 2037: prevent byte/str elements > 255 + self.assertRaises(UnicodeEncodeError, str, java.lang.String(u"caf\xe9 noir")) + self.assertRaises(UnicodeEncodeError, str, java.lang.String(u"abc\u0111efgh")) + + class StringSlicingTest(unittest.TestCase): def test_out_of_bounds(self): @@ -165,7 +172,7 @@ def test_main(): test_support.run_unittest( WrappedStrCmpTest, - IntToStrTest, + StrConstructorTest, StringSlicingTest, FormatTest, DisplayTest, diff --git a/src/org/python/core/Py.java b/src/org/python/core/Py.java --- a/src/org/python/core/Py.java +++ b/src/org/python/core/Py.java @@ -1652,7 +1652,7 @@ static { for (char j = 0; j < 256; j++) { - letters[j] = new PyString(new Character(j).toString()); + letters[j] = new PyString(j); } } @@ -1667,11 +1667,8 @@ static final PyString makeCharacter(int codepoint, boolean toUnicode) { if (toUnicode) { return new PyUnicode(codepoint); - } else if (codepoint > 65536) { - throw new IllegalArgumentException(String.format("Codepoint > 65536 (%d) requires " - + "toUnicode argument", codepoint)); - } else if (codepoint > 256) { - return new PyString((char)codepoint); + } else if (codepoint > 255) { + throw new 
IllegalArgumentException("Cannot create PyString with non-byte value"); } return letters[codepoint]; } diff --git a/src/org/python/core/PyJavaType.java b/src/org/python/core/PyJavaType.java --- a/src/org/python/core/PyJavaType.java +++ b/src/org/python/core/PyJavaType.java @@ -611,8 +611,12 @@ addMethod(new PyBuiltinMethodNarrow("__repr__") { @Override public PyObject __call__() { + /* + * java.lang.Object.toString returns Unicode: preserve as a PyUnicode, then let + * the repr() built-in decide how to handle it. (Also applies to __str__.) + */ String toString = self.getJavaProxy().toString(); - return toString == null ? Py.EmptyString : Py.newString(toString); + return toString == null ? Py.EmptyUnicode : Py.newUnicode(toString); } }); addMethod(new PyBuiltinMethodNarrow("__unicode__") { diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -42,10 +42,19 @@ this(TYPE, ""); } + /** + * Fundamental constructor for PyString objects when the client provides a Java + * String, necessitating that we range check the characters. + * + * @param subType the actual type being constructed + * @param string a Java String to be wrapped + */ public PyString(PyType subType, String string) { super(subType); if (string == null) { throw new IllegalArgumentException("Cannot create PyString from null!"); + } else if (!isBytes(string)) { + throw new IllegalArgumentException("Cannot create PyString with non-byte value"); } this.string = string; } @@ -63,6 +72,40 @@ } /** + * Determine whether a string consists entirely of characters in the range 0 to 255. Only such + * characters are allowed in the PyString (str) type, when it is not a + * {@link PyUnicode}. 
+ * + * @return true if and only if every character has a code less than 256 + */ + private static boolean isBytes(String s) { + int k = s.length(); + if (k == 0) { + return true; + } else { + // Bitwise-or the character codes together in order to test once. + char c = 0; + // Blocks of 8 to reduce loop tests + while (k > 8) { + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + c |= s.charAt(--k); + } + // Now the rest + while (k > 0) { + c |= s.charAt(--k); + } + // We require there to be no bits set from 0x100 upwards + return c < 0x100; + } + } + + /** * Creates a PyString from an already interned String. Just means it won't be reinterned if used * in a place that requires interned Strings. */ @@ -88,16 +131,25 @@ String[] keywords) { ArgParser ap = new ArgParser("str", args, keywords, new String[] {"object"}, 0); PyObject S = ap.getPyObject(0, null); + // Get the textual representation of the object into str/bytes form + String str; + if (S == null) { + str = ""; + } else { + // Let the object tell us its representation: this may be str or unicode. + S = S.__str__(); + if (S instanceof PyUnicode) { + // Encoding will raise UnicodeEncodeError if not 7-bit clean. + str = codecs.encode((PyUnicode)S, null, null); + } else { + // Must be str/bytes, and should be 8-bit clean already. 
+ str = S.toString(); + } + } if (new_.for_type == subtype) { - if (S == null) { - return new PyString(""); - } - return new PyString(S.__str__().toString()); + return new PyString(str); } else { - if (S == null) { - return new PyStringDerived(subtype, ""); - } - return new PyStringDerived(subtype, S.__str__().toString()); + return new PyStringDerived(subtype, str); } } @@ -4606,7 +4658,7 @@ default: throw Py.ValueError("unsupported format character '" - + codecs.encode(Py.newString(spec.type), null, "replace") + "' (0x" + + codecs.encode(Py.newUnicode(spec.type), null, "replace") + "' (0x" + Integer.toHexString(spec.type) + ") at index " + (index - 1)); } diff --git a/src/org/python/core/PyUnicode.java b/src/org/python/core/PyUnicode.java --- a/src/org/python/core/PyUnicode.java +++ b/src/org/python/core/PyUnicode.java @@ -114,7 +114,8 @@ * @param isBasic true if it is known that only BMP characters are present. */ private PyUnicode(PyType subtype, String string, boolean isBasic) { - super(subtype, string); + super(subtype, ""); + this.string = string; translator = isBasic ? 
BASIC : this.chooseIndexTranslator(); } diff --git a/tests/java/org/python/core/BaseBytesTest.java b/tests/java/org/python/core/BaseBytesTest.java --- a/tests/java/org/python/core/BaseBytesTest.java +++ b/tests/java/org/python/core/BaseBytesTest.java @@ -304,22 +304,18 @@ // Need interpreter for exceptions to be formed properly interp = new PythonInterpreter(); // A scary set of objects - final PyObject[] brantub = {Py.None, - new PyInteger(-1), - new PyLong(0x80000000L), - new PyString("\u00A0\u0100\u00A2\u00A3\u00A4"), - new PyString("\u00A0\u00A0\u1000\u00A3\u00A4"), - new PyXRange(3, -2, -1), - new PyXRange(250, 257)}; + final PyObject[] brantub = {Py.None, new PyInteger(-1), // + new PyLong(0x80000000L), // + new PyXRange(3, -2, -1), // + new PyXRange(250, 257) // + }; // The PyException types we should obtain final PyObject[] boobyPrize = {Py.TypeError, // None - Py.ValueError, // -1 - Py.OverflowError, // 0x80000000L - Py.ValueError, // \u0100 byte - Py.ValueError, // \u1000 byte - Py.ValueError, // -1 in iterable - Py.ValueError // 256 in iterable - }; + Py.ValueError, // -1 + Py.OverflowError, // 0x80000000L + Py.ValueError, // -1 in iterable + Py.ValueError // 256 in iterable + }; // Work down the lists for (int dip = 0; dip < brantub.length; dip++) { PyObject aRef = boobyPrize[dip]; -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 2 23:11:53 2014 From: jython-checkins at python.org (jeff.allen) Date: Tue, 02 Dec 2014 22:11:53 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Avoid_PyString_range-test_i?= =?utf-8?q?n_operations_involving_sub-strings=2E?= Message-ID: <20141202221125.105531.29714@psf.io> https://hg.python.org/jython/rev/720e34a4d5be changeset: 7428:720e34a4d5be user: Jeff Allen date: Tue Dec 02 20:44:45 2014 +0000 summary: Avoid PyString range-test in operations involving sub-strings. 
Uses the private "no-check" constructor for these common operations: intends to restore performance in most applications. files: NEWS | 4 ++++ src/org/python/core/Py.java | 14 ++++++++++---- src/org/python/core/PyString.java | 12 +++++++----- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/NEWS b/NEWS --- a/NEWS +++ b/NEWS @@ -1,5 +1,9 @@ Jython NEWS +Jython 2.7b4 + Bugs Fixed + - [ 2037 ] Byte-string containing elements greater than 255 + Jython 2.7b3 Bugs Fixed - [ 2225 ] Jython+django-jython - no module named site diff --git a/src/org/python/core/Py.java b/src/org/python/core/Py.java --- a/src/org/python/core/Py.java +++ b/src/org/python/core/Py.java @@ -633,7 +633,7 @@ } return new PyStringMap(); } - + public static PyUnicode newUnicode(char c) { return (PyUnicode) makeCharacter(c, true); } @@ -1661,14 +1661,20 @@ } public static final PyString makeCharacter(char c) { - return makeCharacter(c, false); + if (c <= 255) { + return letters[c]; + } else { + // This will throw IllegalArgumentException since non-byte value + return new PyString(c); + } } static final PyString makeCharacter(int codepoint, boolean toUnicode) { if (toUnicode) { return new PyUnicode(codepoint); - } else if (codepoint > 255) { - throw new IllegalArgumentException("Cannot create PyString with non-byte value"); + } else if (codepoint < 0 || codepoint > 255) { + // This will throw IllegalArgumentException since non-byte value + return new PyString('\uffff'); } return letters[codepoint]; } diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -727,7 +727,8 @@ @Override protected PyObject pyget(int i) { - return Py.newString(getString().charAt(i)); + // Method is overridden in PyUnicode, so definitely a PyString + return Py.makeCharacter(string.charAt(i)); } @Override @@ -1263,7 +1264,7 @@ // It ought to be None, null, some kind of bytes with the buffer API. 
String stripChars = asStringNullOrError(chars, "strip"); // Strip specified characters or whitespace if stripChars == null - return new PyString(_strip(stripChars)); + return new PyString(_strip(stripChars), true); } } @@ -1433,7 +1434,7 @@ // It ought to be None, null, some kind of bytes with the buffer API. String stripChars = asStringNullOrError(chars, "lstrip"); // Strip specified characters or whitespace if stripChars == null - return new PyString(_lstrip(stripChars)); + return new PyString(_lstrip(stripChars), true); } } @@ -1522,7 +1523,7 @@ // It ought to be None, null, some kind of bytes with the buffer API. String stripChars = asStringNullOrError(chars, "rstrip"); // Strip specified characters or whitespace if stripChars == null - return new PyString(_rstrip(stripChars)); + return new PyString(_rstrip(stripChars), true); } } @@ -2231,7 +2232,8 @@ * @return new object. */ protected PyString fromSubstring(int begin, int end) { - return createInstance(getString().substring(begin, end), true); + // Method is overridden in PyUnicode, so definitely a PyString + return new PyString(getString().substring(begin, end), true); } /** -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 2 23:11:53 2014 From: jython-checkins at python.org (jeff.allen) Date: Tue, 02 Dec 2014 22:11:53 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Avoid_PyString_range-test_i?= =?utf-8?q?n_concatenation_and_join=2E?= Message-ID: <20141202221125.96690.26397@psf.io> https://hg.python.org/jython/rev/521823de34a5 changeset: 7427:521823de34a5 user: Jeff Allen date: Mon Dec 01 23:53:59 2014 +0000 summary: Avoid PyString range-test in concatenation and join. Adds private constructor for use when we can guarantee bytes. Apply in __add__ and join. 
files: src/org/python/core/PyString.java | 57 +++++++++++++----- 1 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -39,7 +39,7 @@ // for PyJavaClass.init() public PyString() { - this(TYPE, ""); + this("", true); } /** @@ -52,7 +52,7 @@ public PyString(PyType subType, String string) { super(subType); if (string == null) { - throw new IllegalArgumentException("Cannot create PyString from null!"); + throw new IllegalArgumentException("Cannot create PyString from null"); } else if (!isBytes(string)) { throw new IllegalArgumentException("Cannot create PyString with non-byte value"); } @@ -72,6 +72,23 @@ } /** + * Local-use constructor in which the client is allowed to guarantee that the + * String argument contains only characters in the byte range. We do not then + * range-check the characters. + * + * @param string a Java String to be wrapped (not null) + * @param isBytes true if the client guarantees we are dealing with bytes + */ + private PyString(String string, boolean isBytes) { + super(TYPE); + if (isBytes || isBytes(string)) { + this.string = string; + } else { + throw new IllegalArgumentException("Cannot create PyString with non-byte value"); + } + } + + /** * Determine whether a string consists entirely of characters in the range 0 to 255. Only such * characters are allowed in the PyString (str) type, when it is not a * {@link PyUnicode}. @@ -228,7 +245,7 @@ if (getClass() == PyString.class) { return this; } - return new PyString(getString()); + return new PyString(getString(), true); } @Override @@ -785,6 +802,18 @@ * not a unicode. * * @param obj to coerce to a String + * @return coerced value or null if it can't be (including unicode) + */ + private static String asStringOrNull(PyObject obj) { + return (obj instanceof PyUnicode) ? 
null : asUTF16StringOrNull(obj); + } + + /** + * Return a String equivalent to the argument. This is a helper function to those methods that + * accept any byte array type (any object that supports a one-dimensional byte buffer), but + * not a unicode. + * + * @param obj to coerce to a String * @return coerced value * @throws PyException if the coercion fails (including unicode) */ @@ -917,21 +946,17 @@ @ExposedMethod(type = MethodType.BINARY, doc = BuiltinDocs.str___add___doc) final PyObject str___add__(PyObject other) { - - if (other instanceof PyUnicode) { + // Expect other to be some kind of byte-like object. + String otherStr = asStringOrNull(other); + if (otherStr != null) { + // Yes it is: concatenate as strings, which are guaranteed byte-like. + return new PyString(getString().concat(otherStr), true); + } else if (other instanceof PyUnicode) { // Convert self to PyUnicode and escalate the problem return decode().__add__(other); - } else { - // Some kind of object with the buffer API - String otherStr = asUTF16StringOrNull(other); - if (otherStr == null) { - // Allow PyObject._basic_add to pick up the pieces or raise informative error - return null; - } else { - // Concatenate as strings - return new PyString(getString().concat(otherStr)); - } + // Allow PyObject._basic_add to pick up the pieces or raise informative error + return null; } } @@ -3161,7 +3186,7 @@ } buf.append(((PyString)item).getString()); } - return new PyString(buf.toString()); + return new PyString(buf.toString(), true); // Guaranteed to be byte-like } final PyUnicode unicodeJoin(PyObject obj) { -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 2 23:11:53 2014 From: jython-checkins at python.org (jeff.allen) Date: Tue, 02 Dec 2014 22:11:53 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Clarify_in_PyString_when_we?= =?utf-8?q?_are_dealing_with_bytes_or_UTF-16=2E?= Message-ID: <20141202221122.90395.16697@psf.io> 
https://hg.python.org/jython/rev/731aa7c968b4 changeset: 7426:731aa7c968b4 user: Jeff Allen date: Mon Dec 01 23:04:52 2014 +0000 summary: Clarify in PyString when we are dealing with bytes or UTF-16. Minor change for clarity, preparatory to avoiding byte-checking (see issue #2037) when data can be guaranteed to be bytes. files: src/org/python/core/PyString.java | 49 ++++++++++-------- 1 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -758,12 +758,14 @@ /** * Return a String equivalent to the argument. This is a helper function to those methods that - * accept any byte array type (any object that supports a one-dimensional byte buffer). + * accept any byte array type (any object that supports a one-dimensional byte buffer), or + * accept a unicode argument which they interpret from its UTF-16 encoded form (the + * internal representation returned by {@link PyUnicode#getString()}). * * @param obj to coerce to a String * @return coerced value or null if it can't be */ - private static String asStringOrNull(PyObject obj) { + private static String asUTF16StringOrNull(PyObject obj) { if (obj instanceof PyString) { // str or unicode object: go directly to the String return ((PyString)obj).getString(); @@ -779,14 +781,15 @@ /** * Return a String equivalent to the argument. This is a helper function to those methods that - * accept any byte array type (any object that supports a one-dimensional byte buffer). + * accept any byte array type (any object that supports a one-dimensional byte buffer), but + * not a unicode. 
* * @param obj to coerce to a String * @return coerced value - * @throws PyException if the coercion fails + * @throws PyException if the coercion fails (including unicode) */ private static String asStringOrError(PyObject obj) throws PyException { - String ret = asStringOrNull(obj); + String ret = (obj instanceof PyUnicode) ? null : asUTF16StringOrNull(obj); if (ret != null) { return ret; } else { @@ -796,22 +799,23 @@ /** * Return a String equivalent to the argument according to the calling conventions of methods - * that accept anything bearing the buffer interface as a byte string, but also - * PyNone. (Or the argument may be omitted, showing up here as null.) These include - * the strip and split methods of str, where a null - * indicates that the criterion is whitespace, and str.translate. + * that accept as a byte string anything bearing the buffer interface, or accept + * PyNone, but not a unicode. (Or the argument may be omitted, + * showing up here as null.) These include the strip and split methods + * of str, where a null indicates that the criterion is whitespace, and + * str.translate. * * @param obj to coerce to a String or null * @param name of method * @return coerced value or null - * @throws PyException if the coercion fails + * @throws PyException if the coercion fails (including unicode) */ private static String asStringNullOrError(PyObject obj, String name) throws PyException { if (obj == null || obj == Py.None) { return null; } else { - String ret = asStringOrNull(obj); + String ret = (obj instanceof PyUnicode) ? null : asUTF16StringOrNull(obj); if (ret != null) { return ret; } else if (name == null) { @@ -826,18 +830,17 @@ /** * Return a String equivalent to the argument according to the calling conventions of the - * certain methods of str. 
Those methods accept anything bearing the buffer - * interface as a byte string, or accept a unicode argument for which they accept responsibility - * to interpret from its UTF16 encoded form (the internal representation returned by - * {@link PyUnicode#getString()}). + * certain methods of str. Those methods accept as a byte string anything bearing + * the buffer interface, or accept a unicode argument which they interpret from its + * UTF-16 encoded form (the internal representation returned by {@link PyUnicode#getString()}). * * @param obj to coerce to a String * @return coerced value * @throws PyException if the coercion fails */ - private static String asBMPStringOrError(PyObject obj) { + private static String asUTF16StringOrError(PyObject obj) { // PyUnicode accepted here. Care required in the client if obj is not basic plane. - String ret = asStringOrNull(obj); + String ret = asUTF16StringOrNull(obj); if (ret != null) { return ret; } else { @@ -852,7 +855,7 @@ @ExposedMethod(doc = BuiltinDocs.str___contains___doc) final boolean str___contains__(PyObject o) { - String other = asStringOrError(o); + String other = asUTF16StringOrError(o); return getString().indexOf(other) >= 0; } @@ -921,7 +924,7 @@ } else { // Some kind of object with the buffer API - String otherStr = asStringOrNull(other); + String otherStr = asUTF16StringOrNull(other); if (otherStr == null) { // Allow PyObject._basic_add to pick up the pieces or raise informative error return null; @@ -3278,7 +3281,7 @@ if (!(prefix instanceof PyTuple)) { // It ought to be PyUnicode or some kind of bytes with the buffer API. - String s = asBMPStringOrError(prefix); + String s = asUTF16StringOrError(prefix); // If s is non-BMP, and this is a PyString (bytes), result will correctly be false. 
return sliceLen >= s.length() && getString().startsWith(s, start); @@ -3286,7 +3289,7 @@ // Loop will return true if this slice starts with any prefix in the tuple for (PyObject prefixObj : ((PyTuple)prefix).getArray()) { // It ought to be PyUnicode or some kind of bytes with the buffer API. - String s = asBMPStringOrError(prefixObj); + String s = asUTF16StringOrError(prefixObj); // If s is non-BMP, and this is a PyString (bytes), result will correctly be false. if (sliceLen >= s.length() && getString().startsWith(s, start)) { return true; @@ -3349,7 +3352,7 @@ if (!(suffix instanceof PyTuple)) { // It ought to be PyUnicode or some kind of bytes with the buffer API. - String s = asBMPStringOrError(suffix); + String s = asUTF16StringOrError(suffix); // If s is non-BMP, and this is a PyString (bytes), result will correctly be false. return substr.endsWith(s); @@ -3357,7 +3360,7 @@ // Loop will return true if this slice ends with any suffix in the tuple for (PyObject suffixObj : ((PyTuple)suffix).getArray()) { // It ought to be PyUnicode or some kind of bytes with the buffer API. - String s = asBMPStringOrError(suffixObj); + String s = asUTF16StringOrError(suffixObj); // If s is non-BMP, and this is a PyString (bytes), result will correctly be false. if (substr.endsWith(s)) { return true; -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 3 01:05:16 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 03 Dec 2014 00:05:16 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Support_noinherit_directive?= =?utf-8?q?_for_generating_derived_classes=2E?= Message-ID: <20141203000508.105501.98690@psf.io> https://hg.python.org/jython/rev/4023d089f0ae changeset: 7429:4023d089f0ae user: Jim Baker date: Tue Dec 02 17:05:02 2014 -0700 summary: Support noinherit directive for generating derived classes. 
Adding constructors is not possible with the rest directive, since this is inheritable, as can be seen with defaultdict.derived inheriting from dict.derived. So this change adds a similar directive, noinherit, which otherwise acts like the rest directive. Additionally, the import directive no longer inherits, given that it's also used for fixups of the generated derived class. files: src/org/python/core/PyDictionaryDerived.java | 18 +++--- src/templates/dict.derived | 2 +- src/templates/gderived.py | 30 +++++++-- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/org/python/core/PyDictionaryDerived.java b/src/org/python/core/PyDictionaryDerived.java --- a/src/org/python/core/PyDictionaryDerived.java +++ b/src/org/python/core/PyDictionaryDerived.java @@ -67,6 +67,15 @@ } } + public PyDictionaryDerived(PyType subtype,ConcurrentMap backingMap,boolean useBackingMap) { + super(subtype,backingMap,useBackingMap); + slots=new PyObject[subtype.getNumSlots()]; + dict=subtype.instDict(); + if (subtype.needsFinalizer()) { + finalizeTrigger=FinalizeTrigger.makeTrigger(this); + } + } + public PyString __str__() { PyType self_type=getType(); PyObject impl=self_type.lookup("__str__"); @@ -1140,15 +1149,6 @@ return super.__coerce_ex__(o); } - public PyDictionaryDerived(PyType subtype,ConcurrentMap backingMap,boolean useBackingMap) { - super(subtype,backingMap,useBackingMap); - slots=new PyObject[subtype.getNumSlots()]; - dict=subtype.instDict(); - if (subtype.needsFinalizer()) { - finalizeTrigger=FinalizeTrigger.makeTrigger(this); - } - } - public String toString() { PyType self_type=getType(); PyObject impl=self_type.lookup("__repr__"); diff --git a/src/templates/dict.derived b/src/templates/dict.derived --- a/src/templates/dict.derived +++ b/src/templates/dict.derived @@ -3,7 +3,7 @@ ctr: incl: object import: java.util.concurrent.ConcurrentMap -rest: +noinherit: public PyDictionaryDerived(PyType subtype, ConcurrentMap backingMap, boolean useBackingMap) { 
super(subtype, backingMap, useBackingMap); slots=new PyObject[subtype.getNumSlots()]; diff --git a/src/templates/gderived.py b/src/templates/gderived.py --- a/src/templates/gderived.py +++ b/src/templates/gderived.py @@ -25,7 +25,6 @@ modif_re = re.compile(r"(?:\((\w+)\))?(\w+)") -added_imports = [] # os.path.samefile unavailable on Windows before Python v3.2 if hasattr(os.path, "samefile"): @@ -41,6 +40,7 @@ priority_order = ['require', 'define', 'base_class', 'want_dict', 'ctr', + 'noinherit', 'incl', 'unary1', 'binary', 'ibinary', @@ -68,6 +68,7 @@ self.want_dict = None self.no_toString = False self.ctr_done = 0 + self.added_imports = [] def debug(self, bindings): for name, val in bindings.items(): @@ -86,8 +87,7 @@ return self.auxiliary[name] def dire_import(self, name, parm, body): - global added_imports - added_imports = [x.strip() for x in parm.split(",")] + self.added_imports = [x.strip() for x in parm.split(",")] def dire_require(self, name, parm, body): if body is not None: @@ -134,7 +134,14 @@ def dire_incl(self, name, parm, body): if body is not None: self.invalid(name, 'non-empty body') - directives.execute(directives.load(parm.strip()+'.derived'),self) + + def load(): + for d in directives.load(parm.strip()+'.derived'): + if d.name not in ('noinherit', 'import'): + yield d + + included_directives = list(load()) + directives.execute(included_directives, self) def dire_ctr(self, name, parm, body): if self.ctr_done: @@ -162,6 +169,13 @@ pair = self.get_aux('pair') self.decls = pair.tbind({'trailer': self.decls, 'last': templ}) + def dire_noinherit(self, name, param, body): + if param: + self.invalid(name, 'non-empty parm') + if body is None: + return + self.add_decl(JavaTemplate(body, start='ClassBodyDeclarations')) + def dire_unary1(self, name, parm, body): if body is not None: self.invalid(name, 'non-empty body') @@ -224,7 +238,7 @@ directives.execute(directives.load(fn), gen) result = gen.generate() result = hack_derived_header(outfile, result) - 
result = add_imports(outfile, result) + result = add_imports(gen, outfile, result) print >> open(outfile, 'w'), result #gen.debug() @@ -253,8 +267,8 @@ return '\n'.join(result) -def add_imports(fn, result): - if not added_imports: +def add_imports(gen, fn, result): + if not gen.added_imports: return result print 'Adding imports for: %s' % fn result = result.splitlines() @@ -264,7 +278,7 @@ for line in result: if not added and line.startswith("import "): added = True - for addition in added_imports: + for addition in gen.added_imports: yield "import %s;" % (addition,) yield line -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 3 01:22:45 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 03 Dec 2014 00:22:45 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Restore_getpass=2Egetpass_t?= =?utf-8?q?urning_off_echo_in_JLine?= Message-ID: <20141203002230.4189.1469@psf.io> https://hg.python.org/jython/rev/5c36bda58baa changeset: 7430:5c36bda58baa user: Jim Baker date: Tue Dec 02 17:22:23 2014 -0700 summary: Restore getpass.getpass turning off echo in JLine Approximately around r7116 with the console refactoring the attribute for getting the console from sys was changed to sys._jy_console, but not updated in the getpass module. 
Unfortunately we do not at this time have testing for such interactive usage, so this was just reported via email on http://sourceforge.net/p/jython/mailman/message/33103064/ files: Lib/getpass.py | 2 +- 1 files changed, 1 insertions(+), 1 deletions(-) diff --git a/Lib/getpass.py b/Lib/getpass.py --- a/Lib/getpass.py +++ b/Lib/getpass.py @@ -29,7 +29,7 @@ stream = sys.stdout try: - terminal = sys._jy_interpreter.reader.terminal + terminal = sys._jy_console.reader.terminal except: return default_getpass(prompt) -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 3 01:59:16 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 03 Dec 2014 00:59:16 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Turn_off_Jython-specific_ge?= =?utf-8?q?tpass=2Egetpass?= Message-ID: <20141203005914.90413.41395@psf.io> https://hg.python.org/jython/rev/dae749f895df changeset: 7431:dae749f895df user: Jim Baker date: Tue Dec 02 17:59:09 2014 -0700 summary: Turn off Jython-specific getpass.getpass Disabling echoing does not reliably work now. Most likely we will have to upgrade to JLine2 for this functionality to be restored. 
files: Lib/getpass.py | 5 ++--- 1 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Lib/getpass.py b/Lib/getpass.py --- a/Lib/getpass.py +++ b/Lib/getpass.py @@ -27,7 +27,6 @@ """ if stream is None: stream = sys.stdout - try: terminal = sys._jy_console.reader.terminal except: @@ -149,8 +148,8 @@ from EasyDialogs import AskPassword except ImportError: if os.name == 'java': - getpass = jython_getpass - else: + # disable this option for now, this does not reliably work + # getpass = jython_getpass getpass = default_getpass else: getpass = AskPassword -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 7 00:47:34 2014 From: jython-checkins at python.org (jeff.allen) Date: Sat, 06 Dec 2014 23:47:34 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Correct_handling_of_str=2Ef?= =?utf-8?q?ormat=28unicode=29?= Message-ID: <20141206234731.21260.35965@psf.io> https://hg.python.org/jython/rev/7e2e9537565f changeset: 7432:7e2e9537565f user: Jeff Allen date: Sat Dec 06 23:36:25 2014 +0000 summary: Correct handling of str.format(unicode) Compatibly with CPython, we narrow implicitly to a str, using the default encoding (which will raise UnicodeEncodeError for characters >127).
files: Lib/test/test_format_jy.py | 7 +++++++ src/org/python/core/PyString.java | 6 ++++++ 2 files changed, 13 insertions(+), 0 deletions(-) diff --git a/Lib/test/test_format_jy.py b/Lib/test/test_format_jy.py --- a/Lib/test/test_format_jy.py +++ b/Lib/test/test_format_jy.py @@ -55,6 +55,13 @@ class FormatMisc(unittest.TestCase): # Odd tests Jython used to fail + def test_str_format_unicode(self): + # Check unicode is down-converted to str silently if possible + self.assertEqual("full half hour", "full {:s} hour".format(u"half")) + self.assertEqual("full \xbd hour", "full {:s} hour".format("\xbd")) + self.assertRaises(UnicodeEncodeError, "full {:s} hour".format, u"\xbd") + self.assertEqual(u"full \xbd hour", u"full {:s} hour".format(u"\xbd")) + def test_mixtures(self) : # Check formatting to a common buffer in PyString result = 'The cube of 0.5 -0.866j is -1 to 0.01%.' diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -3965,6 +3965,12 @@ throw Py.ValueError("Unknown conversion specifier " + chunk.conversion); } + // Check for "{}".format(u"abc") + if (fieldObj instanceof PyUnicode && !(this instanceof PyUnicode)) { + // Down-convert to PyString, at the risk of raising UnicodeEncodingError + fieldObj = ((PyUnicode)fieldObj).__str__(); + } + // The format_spec may be simple, or contained nested replacement fields. 
String formatSpec = chunk.formatSpec; if (chunk.formatSpecNeedsExpanding) { -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 7 06:50:41 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 07 Dec 2014 05:50:41 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Port_=5Fjson=2Ec_to_Java_to?= =?utf-8?q?_speed_up_JSON_encoding/decoding=2E?= Message-ID: <20141207055040.76716.35573@psf.io> https://hg.python.org/jython/rev/5f4c860fb58f changeset: 7433:5f4c860fb58f user: Jim Baker date: Sat Dec 06 22:49:16 2014 -0700 summary: Port _json.c to Java to speed up JSON encoding/decoding. files: Lib/json/tests/test_recursion.py | 112 ++ src/org/python/core/PyString.java | 4 + src/org/python/core/PyUnicode.java | 4 + src/org/python/modules/Setup.java | 1 + src/org/python/modules/_json/Encoder.java | 234 +++++ src/org/python/modules/_json/Scanner.java | 328 +++++++ src/org/python/modules/_json/_json.java | 422 ++++++++++ 7 files changed, 1105 insertions(+), 0 deletions(-) diff --git a/Lib/json/tests/test_recursion.py b/Lib/json/tests/test_recursion.py new file mode 100644 --- /dev/null +++ b/Lib/json/tests/test_recursion.py @@ -0,0 +1,112 @@ +from json.tests import PyTest, CTest + + +class JSONTestObject: + pass + + +class TestRecursion(object): + def test_listrecursion(self): + x = [] + x.append(x) + try: + self.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on list recursion") + x = [] + y = [x] + x.append(y) + try: + self.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on alternating list recursion") + y = [] + x = [y, y] + # ensure that the marker is cleared + self.dumps(x) + + def test_dictrecursion(self): + x = {} + x["test"] = x + try: + self.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on dict recursion") + x = {} + y = {"a": x, "b": x} + # ensure that the marker is cleared + self.dumps(x) + + def 
test_defaultrecursion(self): + class RecursiveJSONEncoder(self.json.JSONEncoder): + recurse = False + def default(self, o): + if o is JSONTestObject: + if self.recurse: + return [JSONTestObject] + else: + return 'JSONTestObject' + return pyjson.JSONEncoder.default(o) + + enc = RecursiveJSONEncoder() + self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"') + enc.recurse = True + try: + enc.encode(JSONTestObject) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on default recursion") + + + def test_highly_nested_objects_decoding(self): + # test that loading highly-nested objects doesn't segfault when C + # accelerations are used. See #12017 + # str + with self.assertRaises(RuntimeError): + self.loads('{"a":' * 100000 + '1' + '}' * 100000) + with self.assertRaises(RuntimeError): + self.loads('{"a":' * 100000 + '[1]' + '}' * 100000) + with self.assertRaises(RuntimeError): + self.loads('[' * 100000 + '1' + ']' * 100000) + # unicode + with self.assertRaises(RuntimeError): + self.loads(u'{"a":' * 100000 + u'1' + u'}' * 100000) + with self.assertRaises(RuntimeError): + self.loads(u'{"a":' * 100000 + u'[1]' + u'}' * 100000) + with self.assertRaises(RuntimeError): + self.loads(u'[' * 100000 + u'1' + u']' * 100000) + + def test_highly_nested_objects_encoding(self): + # See #12051 + l, d = [], {} + for x in xrange(100000): + l, d = [l], {'k':d} + with self.assertRaises(RuntimeError): + self.dumps(l) + with self.assertRaises(RuntimeError): + self.dumps(d) + + def test_endless_recursion(self): + # See #12051 + class EndlessJSONEncoder(self.json.JSONEncoder): + def default(self, o): + """If check_circular is False, this will keep adding another list.""" + return [o] + + # NB: Jython interacts with overflows differently than CPython; + # given that the default function normally raises a ValueError upon + # an overflow, this seems reasonable. 
+ with self.assertRaises(Exception) as cm: + EndlessJSONEncoder(check_circular=False).encode(5j) + self.assertIn(type(cm.exception), [RuntimeError, ValueError]) + + +class TestPyRecursion(TestRecursion, PyTest): pass +class TestCRecursion(TestRecursion, CTest): pass diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -731,6 +731,10 @@ return Py.makeCharacter(string.charAt(i)); } + public int getInt(int i) { + return string.charAt(i); + } + @Override protected PyObject getslice(int start, int stop, int step) { if (step > 0 && stop < start) { diff --git a/src/org/python/core/PyUnicode.java b/src/org/python/core/PyUnicode.java --- a/src/org/python/core/PyUnicode.java +++ b/src/org/python/core/PyUnicode.java @@ -731,6 +731,10 @@ return Py.makeCharacter(codepoint, true); } + public int getInt(int i) { + return getString().codePointAt(translator.utf16Index(i)); + } + private class SubsequenceIteratorImpl implements Iterator { private int current, k, stop, step; diff --git a/src/org/python/modules/Setup.java b/src/org/python/modules/Setup.java --- a/src/org/python/modules/Setup.java +++ b/src/org/python/modules/Setup.java @@ -34,6 +34,7 @@ "_functools:org.python.modules._functools._functools", "_hashlib", "_io:org.python.modules._io._io", + "_json:org.python.modules._json._json", "_jythonlib:org.python.modules._jythonlib._jythonlib", "_marshal", "_py_compile", diff --git a/src/org/python/modules/_json/Encoder.java b/src/org/python/modules/_json/Encoder.java new file mode 100644 --- /dev/null +++ b/src/org/python/modules/_json/Encoder.java @@ -0,0 +1,234 @@ +package org.python.modules._json; + +import org.python.core.ArgParser; +import org.python.core.Py; +import org.python.core.PyDictionary; +import org.python.core.PyException; +import org.python.core.PyFloat; +import org.python.core.PyInteger; +import org.python.core.PyLong; +import org.python.core.PyList; 
+import org.python.core.PyObject; +import org.python.core.PyString; +import org.python.core.PyTuple; +import org.python.core.PyType; +import org.python.core.PyUnicode; +import org.python.expose.ExposedGet; +import org.python.expose.ExposedType; + + at ExposedType(name = "_json.encoder", base = PyObject.class) +public class Encoder extends PyObject { + + public static final PyType TYPE = PyType.fromClass(Encoder.class); + + @ExposedGet + public final String __module__ = "_json"; + + final PyDictionary markers; + final PyObject defaultfn; + final PyObject encoder; + final PyObject indent; + final PyObject key_separator; + final PyObject item_separator; + final PyObject sort_keys; + final boolean skipkeys; + final boolean allow_nan; + + public Encoder(PyObject[] args, String[] kwds) { + super(); + ArgParser ap = new ArgParser("encoder", args, kwds, + new String[]{"markers", "default", "encoder", "indent", + "key_separator", "item_separator", "sort_keys", "skipkeys", "allow_nan"}); + ap.noKeywords(); + PyObject m = ap.getPyObject(0); + markers = m == Py.None ? 
null : (PyDictionary) m; + defaultfn = ap.getPyObject(1); + encoder = ap.getPyObject(2); + indent = ap.getPyObject(3); + key_separator = ap.getPyObject(4); + item_separator = ap.getPyObject(5); + sort_keys = ap.getPyObject(6); + skipkeys = ap.getPyObject(7).__nonzero__(); + allow_nan = ap.getPyObject(8).__nonzero__(); + } + + public PyObject __call__(PyObject obj) { + return __call__(obj, Py.Zero); + } + + public PyObject __call__(PyObject obj, PyObject indent_level) { + PyList rval = new PyList(); + listencode_obj(rval, obj, 0); + return rval; + } + + private PyString encode_float(PyObject obj) { + /* Return the JSON representation of a PyFloat */ + double i = obj.asDouble(); + if (Double.isInfinite(i) || Double.isNaN(i)) { + if (!allow_nan) { + throw Py.ValueError("Out of range float values are not JSON compliant"); + } + if (i == Double.POSITIVE_INFINITY) { + return new PyString("Infinity"); + } else if (i == Double.NEGATIVE_INFINITY) { + return new PyString("-Infinity"); + } else { + return new PyString("NaN"); + } + } + /* Use a better float format here? 
*/ + return obj.__repr__(); + } + + private PyString encode_string(PyObject obj) { + /* Return the JSON representation of a string */ + return (PyString) encoder.__call__(obj); + } + + private void listencode_obj(PyList rval, PyObject obj, int indent_level) { + /* Encode Python object obj to a JSON term, rval is a PyList */ + if (obj == Py.None) { + rval.append(new PyString("null")); + } else if (obj == Py.True) { + rval.append(new PyString("true")); + } else if (obj == Py.False) { + rval.append(new PyString("false")); + } else if (obj instanceof PyString) { + rval.append(encode_string(obj)); + } else if (obj instanceof PyInteger || obj instanceof PyLong) { + rval.append(obj.__str__()); + } else if (obj instanceof PyFloat) { + rval.append(encode_float(obj)); + } else if (obj instanceof PyList || obj instanceof PyTuple) { + listencode_list(rval, obj, indent_level); + } else if (obj instanceof PyDictionary) { + listencode_dict(rval, (PyDictionary) obj, indent_level); + } else { + PyObject ident = null; + if (markers != null) { + boolean contained = false; + try { + contained = markers.__contains__(obj); + } catch (PyException pye) { + // ignore objects that are not hashable, so they can be + // potentially serialized with defaultfn + if (!pye.match(Py.TypeError)) throw pye; + } + if (contained) { + throw Py.ValueError("Circular reference detected"); + } + ident = Py.newInteger(Py.id(obj)); + markers.__setitem__(ident, obj); + } + if (defaultfn == Py.None) { + throw Py.TypeError(String.format(".80s is not JSON serializable", obj.__repr__())); + } + + PyObject newobj; + try { + newobj = defaultfn.__call__(obj); + } catch (StackOverflowError e) { + if (markers == Py.None) { + throw e; + } else { + throw Py.ValueError("Stack overflow in JSON serialization"); + } + } + listencode_obj(rval, newobj, indent_level); + if (ident != null) { + markers.__delitem__(ident); + } + } + } + + private void listencode_dict(PyList rval, PyDictionary dct, int indent_level) { + /* Encode 
Python dict dct a JSON term */ + + PyObject ident = null; + + if (dct.__len__() == 0) { + rval.append(new PyString("{}")); + return; + } + + if (markers != null) { + ident = Py.newInteger(Py.id(dct)); + if (markers.__contains__(ident)) { + throw Py.ValueError("Circular reference detected"); + } + markers.__setitem__(ident, dct); + } + rval.append(new PyString("{")); + + /* TODO: C speedup not implemented for sort_keys */ + + int idx = 0; + for (PyObject key : dct.asIterable()) { + PyString kstr; + + if (key instanceof PyString || key instanceof PyUnicode) { + kstr = (PyString) key; + } else if (key instanceof PyFloat) { + kstr = encode_float(key); + } else if (key instanceof PyInteger || key instanceof PyLong) { + kstr = key.__str__(); + } else if (key == Py.True) { + kstr = new PyString("true"); + } else if (key == Py.False) { + kstr = new PyString("false"); + } else if (key == Py.None) { + kstr = new PyString("null"); + } else if (skipkeys) { + continue; + } else { + throw Py.TypeError(String.format("keys must be a string: %.80s", key.__repr__())); + } + + if (idx > 0) { + rval.append(item_separator); + } + + PyObject value = dct.__getitem__(key); + PyString encoded = encode_string(kstr); + rval.append(encoded); + rval.append(key_separator); + listencode_obj(rval, value, indent_level); + idx += 1; + } + + if (ident != null) { + markers.__delitem__(ident); + } + rval.append(new PyString("}")); + } + + + private void listencode_list(PyList rval, PyObject seq, int indent_level) { + PyObject ident = null; + + if (markers != null) { + ident = Py.newInteger(Py.id(seq)); + if (markers.__contains__(ident)) { + throw Py.ValueError("Circular reference detected"); + } + markers.__setitem__(ident, seq); + } + + rval.append(new PyString("[")); + + int i = 0; + for (PyObject obj : seq.asIterable()) { + if (i > 0) { + rval.append(item_separator); + } + listencode_obj(rval, obj, indent_level); + i++; + } + + if (ident != null) { + markers.__delitem__(ident); + } + 
rval.append(new PyString("]")); + } +} diff --git a/src/org/python/modules/_json/Scanner.java b/src/org/python/modules/_json/Scanner.java new file mode 100644 --- /dev/null +++ b/src/org/python/modules/_json/Scanner.java @@ -0,0 +1,328 @@ +package org.python.modules._json; + +import org.python.core.Py; +import org.python.core.PyDictionary; +import org.python.core.PyList; +import org.python.core.PyObject; +import org.python.core.PyString; +import org.python.core.PyTuple; +import org.python.core.PyType; +import org.python.expose.ExposedGet; +import org.python.expose.ExposedType; + + + at ExposedType(name = "_json.Scanner", base = PyObject.class) +public class Scanner extends PyObject { + + public static final PyType TYPE = PyType.fromClass(Scanner.class); + + @ExposedGet + public final String __module__ = "_json"; + + final String encoding; + final boolean strict; + final PyObject object_hook; + final PyObject pairs_hook; + final PyObject parse_float; + final PyObject parse_int; + final PyObject parse_constant; + + public Scanner(PyObject context) { + super(); + PyObject encoding_obj = context.__getattr__("encoding"); + encoding = encoding_obj == Py.None ? 
"utf-8" : context.__getattr__("encoding").asString(); + strict = context.__getattr__("strict").__nonzero__(); + object_hook = context.__getattr__("object_hook"); + pairs_hook = context.__getattr__("object_pairs_hook"); + parse_float = context.__getattr__("parse_float"); + parse_int = context.__getattr__("parse_int"); + parse_constant = context.__getattr__("parse_constant"); + } + + public PyObject __call__(PyObject string, PyObject idx) { + return _scan_once((PyString)string, idx.asInt()); + } + + private static boolean IS_WHITESPACE(int c) { + return (c == ' ') || (c == '\t') || (c == '\n') || (c == '\r'); + } + + static PyTuple valIndex(PyObject obj, int i) { + return new PyTuple(obj, Py.newInteger(i)); + } + + public PyTuple _parse_object(PyString pystr, int idx) { // }, Py_ssize_t *next_idx_ptr) { + /* Read a JSON object from PyString pystr. + idx is the index of the first character after the opening curly brace. + + Returns a new PyTuple of a PyObject (usually a dict, but object_hook can change that) + and the next_idx to the first character after + the closing curly brace. 
+ */ + PyString str = pystr; + int end_idx = pystr.__len__() - 1; + PyList pairs = new PyList(); + PyObject item; + PyObject key; + PyObject val; + + /* skip whitespace after { */ + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + + /* only loop if the object is non-empty */ + if (idx <= end_idx && str.getInt(idx) != '}') { + while (idx <= end_idx) { + /* read key */ + if (str.getInt(idx) != '"') { + _json.raise_errmsg("Expecting property name", pystr, idx); + } + PyTuple key_idx = _json.scanstring(pystr, idx + 1, encoding, strict); + key = key_idx.pyget(0); + idx = key_idx.pyget(1).asInt(); + + /* skip whitespace between key and : delimiter, read :, skip whitespace */ + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + if (idx > end_idx || str.getInt(idx) != ':') { + _json.raise_errmsg("Expecting : delimiter", pystr, idx); + } + idx++; + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + + /* read any JSON data type */ + PyTuple val_idx = _scan_once(pystr, idx); + val = val_idx.pyget(0); + idx = val_idx.pyget(1).asInt(); + pairs.append(new PyTuple(key, val)); + + /* skip whitespace before } or , */ + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + + /* bail if the object is closed or we didn't get the , delimiter */ + if (idx > end_idx) break; + if (str.getInt(idx) == '}') { + break; + } else if (str.getInt(idx) != ',') { + _json.raise_errmsg("Expecting , delimiter", pystr, idx); + } + idx++; + + /* skip whitespace after , delimiter */ + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + } + } + /* verify that idx < end_idx, str[idx] should be '}' */ + if (idx > end_idx || str.getInt(idx) != '}') { + _json.raise_errmsg("Expecting object", pystr, end_idx); + } + + /* if pairs_hook is not None: rval = object_pairs_hook(pairs) */ + if (pairs_hook != Py.None) { + return valIndex(pairs_hook.__call__(pairs), idx + 1); + } + + PyObject rval = new PyDictionary(); + 
((PyDictionary)rval).update(pairs); + + /* if object_hook is not None: rval = object_hook(rval) */ + if (object_hook != Py.None) { + rval = object_hook.__call__(rval); + } + + return valIndex(rval, idx + 1); + } + + public PyTuple _parse_array(PyString pystr, int idx) { + /* Read a JSON array from PyString pystr. + + + Returns a new PyTuple of a PyList and next_idx (first character after + the closing brace.) + */ + PyString str = pystr; + int end_idx = pystr.__len__() - 1; + PyList rval = new PyList(); + int next_idx; + + /* skip whitespace after [ */ + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + + /* only loop if the array is non-empty */ + if (idx <= end_idx && str.getInt(idx) != ']') { + while (idx <= end_idx) { + + /* read any JSON term and de-tuplefy the (rval, idx) */ + PyTuple val_idx = _scan_once(pystr, idx); + PyObject val = val_idx.pyget(0); + idx = val_idx.pyget(1).asInt(); + rval.append(val); + + /* skip whitespace between term and , */ + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + + /* bail if the array is closed or we didn't get the , delimiter */ + if (idx > end_idx) break; + if (str.getInt(idx) == ']') { + break; + } else if (str.getInt(idx) != ',') { + _json.raise_errmsg("Expecting , delimiter", pystr, idx); + } + idx++; + + /* skip whitespace after , */ + while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++; + } + } + + /* verify that idx < end_idx, str[idx] should be ']' */ + if (idx > end_idx || str.getInt(idx) != ']') { + _json.raise_errmsg("Expecting object", pystr, end_idx); + } + return valIndex(rval, idx + 1); + } + + + public PyTuple _scan_once(PyString pystr, int idx) { + /* Read one JSON term (of any kind) from PyString pystr. 
+ idx is the index of the first character of the term + + Returns a new PyTuple of a PyObject representation of the term along + with the next_idx + */ + PyString str = pystr; + int length = pystr.__len__(); + if (idx >= length) { + throw Py.StopIteration(""); + } + switch (str.getInt(idx)) { + case '"': + /* string */ + return _json.scanstring(pystr, idx + 1, encoding, strict); + case '{': + /* object */ + return _parse_object(pystr, idx + 1); + case '[': + /* array */ + return _parse_array(pystr, idx + 1); + case 'n': + /* null */ + if ((idx + 3 < length) && str.getInt(idx + 1) == 'u' && str.getInt(idx + 2) == 'l' && str.getInt(idx + 3) == 'l') { + return valIndex(Py.None, idx + 4); + } + break; + case 't': + /* true */ + if ((idx + 3 < length) && str.getInt(idx + 1) == 'r' && str.getInt(idx + 2) == 'u' && str.getInt(idx + 3) == 'e') { + return valIndex(Py.True, idx + 4); + } + break; + case 'f': + /* false */ + if ((idx + 4 < length) && str.getInt(idx + 1) == 'a' && str.getInt(idx + 2) == 'l' && str.getInt(idx + 3) == 's' && str.getInt(idx + 4) == 'e') { + return valIndex(Py.False, idx + 5); + } + break; + case 'N': + /* NaN */ + if ((idx + 2 < length) && str.getInt(idx + 1) == 'a' && str.getInt(idx + 2) == 'N') { + return _parse_constant("NaN", idx + 3); + } + break; + case 'I': + /* Infinity */ + if ((idx + 7 < length) && str.getInt(idx + 1) == 'n' && str.getInt(idx + 2) == 'f' && str.getInt(idx + 3) == 'i' && str.getInt(idx + 4) == 'n' && str.getInt(idx + 5) == 'i' && str.getInt(idx + 6) == 't' && str.getInt(idx + 7) == 'y') { + return _parse_constant("Infinity", idx + 8); + } + break; + case '-': + /* -Infinity */ + if ((idx + 8 < length) && str.getInt(idx + 1) == 'I' && str.getInt(idx + 2) == 'n' && str.getInt(idx + 3) == 'f' && str.getInt(idx + 4) == 'i' && str.getInt(idx + 5) == 'n' && str.getInt(idx + 6) == 'i' && str.getInt(idx + 7) == 't' && str.getInt(idx + 8) == 'y') { + return _parse_constant("-Infinity", idx + 9); + } + break; + } + /* Didn't find 
a string, object, array, or named constant. Look for a number. */ + return _match_number(pystr, idx); + } + + public PyTuple _parse_constant(String constant, int idx) { + return valIndex(parse_constant.__call__(Py.newString(constant)), idx); + } + + public PyTuple _match_number(PyString pystr, int start) { + /* Read a JSON number from PyString pystr. + idx is the index of the first character of the number + + Returns a new PyObject representation of that number: + PyInt, PyLong, or PyFloat. + May return other types if parse_int or parse_float are set + along with index to the first character after + the number. + */ + PyString str = pystr; + int end_idx = pystr.__len__() - 1; + int idx = start; + boolean is_float = false; + + /* read a sign if it's there, make sure it's not the end of the string */ + if (str.getInt(idx) == '-') { + idx++; + if (idx > end_idx) { + throw Py.StopIteration(""); + } + } + + /* read as many integer digits as we find as long as it doesn't start with 0 */ + if (str.getInt(idx) >= '1' && str.getInt(idx) <= '9') { + idx++; + while (idx <= end_idx && str.getInt(idx) >= '0' && str.getInt(idx) <= '9') idx++; + } + /* if it starts with 0 we only expect one integer digit */ + else if (str.getInt(idx) == '0') { + idx++; + } + /* no integer digits, error */ + else { + throw Py.StopIteration(""); + } + + /* if the next char is '.' followed by a digit then read all float digits */ + if (idx < end_idx && str.getInt(idx) == '.' 
&& str.getInt(idx + 1) >= '0' && str.getInt(idx + 1) <= '9') { + is_float = true; + idx += 2; + while (idx <= end_idx && str.getInt(idx) >= '0' && str.getInt(idx) <= '9') idx++; + } + + /* if the next char is 'e' or 'E' then maybe read the exponent (or backtrack) */ + if (idx < end_idx && (str.getInt(idx) == 'e' || str.getInt(idx) == 'E')) { + + /* save the index of the 'e' or 'E' just in case we need to backtrack */ + int e_start = idx; + idx++; + + /* read an exponent sign if present */ + if (idx < end_idx && (str.getInt(idx) == '-' || str.getInt(idx) == '+')) idx++; + + /* read all digits */ + while (idx <= end_idx && str.getInt(idx) >= '0' && str.getInt(idx) <= '9') idx++; + + /* if we got a digit, then parse as float. if not, backtrack */ + if (str.getInt(idx - 1) >= '0' && str.getInt(idx - 1) <= '9') { + is_float = true; + } else { + idx = e_start; + } + } + + /* copy the section we determined to be a number */ + PyString numstr = (PyString) str.__getslice__(Py.newInteger(start), Py.newInteger(idx)); + if (is_float) { + /* parse as a float using a fast path if available, otherwise call user defined method */ + return valIndex(parse_float.__call__(numstr), idx); + } else { + /* parse as an int using a fast path if available, otherwise call user defined method */ + return valIndex(parse_int.__call__(numstr), idx); + } + } + + +} diff --git a/src/org/python/modules/_json/_json.java b/src/org/python/modules/_json/_json.java new file mode 100644 --- /dev/null +++ b/src/org/python/modules/_json/_json.java @@ -0,0 +1,422 @@ +/* Copyright (c) Jython Developers */ +package org.python.modules._json; + +import org.python.core.ArgParser; +import org.python.core.ClassDictInit; +import org.python.core.Py; +import org.python.core.PyBuiltinFunctionNarrow; +import org.python.core.PyList; +import org.python.core.PyObject; +import org.python.core.PyString; +import org.python.core.PyTuple; +import org.python.core.PyUnicode; +import org.python.core.codecs; +import 
org.python.expose.ExposedGet; + +import java.util.Iterator; + +/** + * This module is a nearly exact line by line port of _json.c to Java. Names and comments are retained + * to make it easy to follow, but classes and methods are modified to following Java calling conventions. + * + * (Retained comments use the standard commenting convention for C.) + */ +public class _json implements ClassDictInit { + + public static final PyString __doc__ = new PyString("Port of _json C module."); + + public static void classDictInit(PyObject dict) { + dict.__setitem__("__name__", new PyString("_json")); + dict.__setitem__("__doc__", __doc__); + dict.__setitem__("encode_basestring_ascii", new EncodeBasestringAsciiFunction()); + dict.__setitem__("make_encoder", Encoder.TYPE); + dict.__setitem__("make_scanner", Scanner.TYPE); + dict.__setitem__("scanstring", new ScanstringFunction()); + dict.__setitem__("__module__", new PyString("_json")); + + // ensure __module__ is set properly in these modules, + // based on how the module name lookups are chained + Encoder.TYPE.setName("_json.Encoder"); + Scanner.TYPE.setName("_json.Scanner"); + + // Hide from Python + dict.__setitem__("classDictInit", null); + } + + private static PyObject errmsg_fn; + + private static synchronized PyObject get_errmsg_fn() { + if (errmsg_fn == null) { + PyObject json = org.python.core.__builtin__.__import__("json"); + if (json != null) { + PyObject decoder = json.__findattr__("decoder"); + if (decoder != null) { + errmsg_fn = decoder.__findattr__("errmsg"); + } + } + } + return errmsg_fn; + } + + static void raise_errmsg(String msg, PyObject s) { + raise_errmsg(msg, s, Py.None, Py.None); + } + + static void raise_errmsg(String msg, PyObject s, int pos) { + raise_errmsg(msg, s, Py.newInteger(pos), Py.None); + } + + static void raise_errmsg(String msg, PyObject s, PyObject pos, PyObject end) { + /* Use the Python function json.decoder.errmsg to raise a nice + looking ValueError exception */ + final PyObject 
errmsg_fn = get_errmsg_fn(); + if (errmsg_fn != null) { + throw Py.ValueError(errmsg_fn.__call__(Py.newString(msg), s, pos, end).asString()); + } else { + throw Py.ValueError(msg); + } + } + + static class ScanstringFunction extends PyBuiltinFunctionNarrow { + ScanstringFunction() { + super("scanstring", 2, 4, "scanstring"); + } + + @Override + @ExposedGet(name = "__module__") + public PyObject getModule() { + return new PyString("_json"); + } + + + @Override + public PyObject __call__(PyObject s, PyObject end) { + return __call__(s, end, new PyString("utf-8"), Py.True); + } + + @Override + public PyObject __call__(PyObject s, PyObject end, PyObject encoding) { + return __call__(s, end, encoding, Py.True); + } + + @Override + public PyObject __call__(PyObject[] args, String[] kwds) { + ArgParser ap = new ArgParser("scanstring", args, kwds, new String[]{ + "s", "end", "encoding", "strict"}, 2); + return __call__( + ap.getPyObject(0), + ap.getPyObject(1), + ap.getPyObject(2, new PyString("utf-8")), + ap.getPyObject(3, Py.True)); + } + + @Override + public PyObject __call__(PyObject s, PyObject end, PyObject encoding, PyObject strict) { + // but rethrow in case it does work - see the test case for issue 362 + int end_idx = end.asIndex(Py.OverflowError); + boolean is_strict = strict.__nonzero__(); + if (s instanceof PyString) { + return scanstring((PyString) s, end_idx, + encoding == Py.None ? 
null : encoding.toString(), is_strict); + } else { + throw Py.TypeError(String.format( + "first argument must be a string, not %.80s", + s.getType().fastGetName())); + } + } + + } + + static PyTuple scanstring(PyString pystr, int end, String encoding, boolean strict) { + int len = pystr.__len__(); + int begin = end - 1; + if (end < 0 || len <= end) { + throw Py.ValueError("end is out of bounds"); + } + int next; + final PyList chunks = new PyList(); + while (true) { + /* Find the end of the string or the next escape */ + int c = 0; + + for (next = end; next < len; next++) { + c = pystr.getInt(next); + if (c == '"' || c == '\\') { + break; + } else if (strict && c <= 0x1f) { + raise_errmsg("Invalid control character at", pystr, next); + } + } + if (!(c == '"' || c == '\\')) { + raise_errmsg("Unterminated string starting at", pystr, begin); + } + + /* Pick up this chunk if it's not zero length */ + if (next != end) { + PyString strchunk = (PyString) pystr.__getslice__(Py.newInteger(end), Py.newInteger(next)); + if (strchunk instanceof PyUnicode) { + chunks.append(strchunk); + } else { + chunks.append(codecs.decode(strchunk, encoding, null)); + } + } + next++; + if (c == '"') { + end = next; + break; + } + if (next == len) { + raise_errmsg("Unterminated string starting at", pystr, begin); + } + c = pystr.getInt(next); + if (c != 'u') { + /* Non-unicode backslash escapes */ + end = next + 1; + switch (c) { + case '"': + break; + case '\\': + break; + case '/': + break; + case 'b': + c = '\b'; + break; + case 'f': + c = '\f'; + break; + case 'n': + c = '\n'; + break; + case 'r': + c = '\r'; + break; + case 't': + c = '\t'; + break; + default: + c = 0; + } + if (c == 0) { + raise_errmsg("Invalid \\escape", pystr, end - 2); + } + } else { + c = 0; + next++; + end = next + 4; + if (end >= len) { + raise_errmsg("Invalid \\uXXXX escape", pystr, next - 1); + } + /* Decode 4 hex digits */ + for (; next < end; next++) { + int digit = pystr.getInt(next); + c <<= 4; + switch 
(digit) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + c |= (digit - '0'); + break; + case 'a': + case 'b': + case 'c': + case 'd': + case 'e': + case 'f': + c |= (digit - 'a' + 10); + break; + case 'A': + case 'B': + case 'C': + case 'D': + case 'E': + case 'F': + c |= (digit - 'A' + 10); + break; + default: + raise_errmsg("Invalid \\uXXXX escape", pystr, end - 5); + } + } + /* Surrogate pair */ + if ((c & 0xfc00) == 0xd800) { + int c2 = 0; + if (end + 6 >= len) { + raise_errmsg("Unpaired high surrogate", pystr, end - 5); + } + if (pystr.getInt(next++) != '\\' || pystr.getInt(next++) != 'u') { + raise_errmsg("Unpaired high surrogate", pystr, end - 5); + } + end += 6; + /* Decode 4 hex digits */ + for (; next < end; next++) { + int digit = pystr.getInt(next); + c2 <<= 4; + switch (digit) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + c2 |= (digit - '0'); + break; + case 'a': + case 'b': + case 'c': + case 'd': + case 'e': + case 'f': + c2 |= (digit - 'a' + 10); + break; + case 'A': + case 'B': + case 'C': + case 'D': + case 'E': + case 'F': + c2 |= (digit - 'A' + 10); + break; + default: + raise_errmsg("Invalid \\uXXXX escape", pystr, end - 5); + } + } + if ((c2 & 0xfc00) != 0xdc00) { + raise_errmsg("Unpaired high surrogate", pystr, end - 5); + } + c = 0x10000 + (((c - 0xd800) << 10) | (c2 - 0xdc00)); + } else if ((c & 0xfc00) == 0xdc00) { + raise_errmsg("Unpaired low surrogate", pystr, end - 5); + } + } + chunks.append(new PyUnicode(c)); + } + + return new PyTuple(Py.EmptyUnicode.join(chunks), Py.newInteger(end)); + } + + static class EncodeBasestringAsciiFunction extends PyBuiltinFunctionNarrow { + EncodeBasestringAsciiFunction() { + super("encode_basestring_ascii", 1, 1, "encode_basestring_ascii"); + } + + @Override + @ExposedGet(name = "__module__") + public PyObject getModule() { + return new 
PyString("_json"); + } + + @Override + public PyObject __call__(PyObject pystr) { + return encode_basestring_ascii(pystr); + } + } + + static PyString encode_basestring_ascii(PyObject pystr) { + if (pystr instanceof PyUnicode) { + return ascii_escape((PyUnicode) pystr); + } else if (pystr instanceof PyString) { + return ascii_escape((PyString) pystr); + } else { + throw Py.TypeError(String.format( + "first argument must be a string, not %.80s", + pystr.getType().fastGetName())); + } + } + + private static PyString ascii_escape(PyUnicode pystr) { + StringBuilder rval = new StringBuilder(pystr.__len__()); + rval.append("\""); + for (Iterator iter = pystr.newSubsequenceIterator(); iter.hasNext(); ) { + _write_char(rval, iter.next()); + } + rval.append("\""); + return new PyString(rval.toString()); + } + + private static PyString ascii_escape(PyString pystr) { + int len = pystr.__len__(); + String s = pystr.getString(); + StringBuilder rval = new StringBuilder(len); + rval.append("\""); + for (int i = 0; i < len; i++) { + int c = s.charAt(i); + if (c > 127) { + return ascii_escape(new PyUnicode(codecs.PyUnicode_DecodeUTF8(s, null))); + } + _write_char(rval, c); + } + rval.append("\""); + return new PyString(rval.toString()); + } + + private static void _write_char(StringBuilder builder, int c) { + /* Escape unicode code point c to ASCII escape sequences + in char *output. 
output must have at least 12 bytes unused to + accommodate an escaped surrogate pair "\ u XXXX \ u XXXX" */ + if (c >= ' ' && c <= '~' && c != '\\' & c != '"') { + builder.append((char) c); + } else { + _ascii_escape_char(builder, c); + } + } + + private static void _write_hexchar(StringBuilder builder, int c) { + builder.append("0123456789abcdef".charAt(c & 0xf)); + } + + private static void _ascii_escape_char(StringBuilder builder, int c) { + builder.append('\\'); + switch (c) { + case '\\': + builder.append((char) c); + break; + case '"': + builder.append((char) c); + break; + case '\b': + builder.append('b'); + break; + case '\f': + builder.append('f'); + break; + case '\n': + builder.append('n'); + break; + case '\r': + builder.append('r'); + break; + case '\t': + builder.append('t'); + break; + default: + if (c >= 0x10000) { + /* UTF-16 surrogate pair */ + int v = c - 0x10000; + c = 0xd800 | ((v >> 10) & 0x3ff); + builder.append('u'); + _write_hexchar(builder, c >> 12); + _write_hexchar(builder, c >> 8); + _write_hexchar(builder, c >> 4); + _write_hexchar(builder, c); + c = 0xdc00 | (v & 0x3ff); + builder.append('\\'); + } + builder.append('u'); + _write_hexchar(builder, c >> 12); + _write_hexchar(builder, c >> 8); + _write_hexchar(builder, c >> 4); + _write_hexchar(builder, c); + } + } +} -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Mon Dec 8 02:59:25 2014 From: jython-checkins at python.org (jim.baker) Date: Mon, 08 Dec 2014 01:59:25 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Check_functions_in_=5Fcodec?= =?utf-8?q?s_if_str_params_are_ascii?= Message-ID: <20141208015922.21260.22245@psf.io> https://hg.python.org/jython/rev/cd5838289304 changeset: 7434:cd5838289304 user: Jim Baker date: Sun Dec 07 18:59:17 2014 -0700 summary: Check functions in _codecs if str params are ascii files: Lib/test/test_json.py | 20 --------- src/org/python/modules/_codecs.java | 36 ++++++++++++---- 2 files changed, 26 insertions(+), 
30 deletions(-) diff --git a/Lib/test/test_json.py b/Lib/test/test_json.py deleted file mode 100644 --- a/Lib/test/test_json.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Tests for json. - -The tests for json are defined in the json.tests package; -the test_suite() function there returns a test suite that's ready to -be run. -""" - -import json.tests -import test.test_support - -from json.tests.test_unicode import TestUnicode - -def test_main(): - #FIXME: Investigate why test_bad_encoding isn't working in Jython. - del TestUnicode.test_bad_encoding - test.test_support.run_unittest(json.tests.test_suite()) - - -if __name__ == "__main__": - test_main() diff --git a/src/org/python/modules/_codecs.java b/src/org/python/modules/_codecs.java --- a/src/org/python/modules/_codecs.java +++ b/src/org/python/modules/_codecs.java @@ -36,12 +36,28 @@ codecs.register(search_function); } - public static PyTuple lookup(String encoding) { - return codecs.lookup(encoding); + private static String _castString(PyString pystr) { + // Jython used to treat String as equivalent to PyString, or maybe PyUnicode, as + // it made sense. We need to be more careful now! Insert this cast check as necessary + // to ensure the appropriate compliance. 
+ if (pystr == null) { + return null; + } + String s = pystr.toString(); + if (pystr instanceof PyUnicode) { + return s; + } else { + // May throw UnicodeEncodeError, per CPython behavior + return codecs.PyUnicode_EncodeASCII(s, s.length(), null); + } } - public static PyObject lookup_error(String handlerName) { - return codecs.lookup_error(handlerName); + public static PyTuple lookup(PyString encoding) { + return codecs.lookup(_castString(encoding)); + } + + public static PyObject lookup_error(PyString handlerName) { + return codecs.lookup_error(_castString(handlerName)); } public static void register_error(String name, PyObject errorHandler) { @@ -68,7 +84,7 @@ * @param encoding name of encoding (to look up in codec registry) * @return Unicode string decoded from bytes */ - public static PyObject decode(PyString bytes, String encoding) { + public static PyObject decode(PyString bytes, PyString encoding) { return decode(bytes, encoding, null); } @@ -85,8 +101,8 @@ * @param errors error policy name (e.g. "ignore") * @return Unicode string decoded from bytes */ - public static PyObject decode(PyString bytes, String encoding, String errors) { - return codecs.decode(bytes, encoding, errors); + public static PyObject decode(PyString bytes, PyString encoding, PyString errors) { + return codecs.decode(bytes, _castString(encoding), _castString(errors)); } /** @@ -109,7 +125,7 @@ * @param encoding name of encoding (to look up in codec registry) * @return bytes object encoding unicode */ - public static PyString encode(PyUnicode unicode, String encoding) { + public static PyString encode(PyUnicode unicode, PyString encoding) { return encode(unicode, encoding, null); } @@ -126,8 +142,8 @@ * @param errors error policy name (e.g. 
"ignore") * @return bytes object encoding unicode */ - public static PyString encode(PyUnicode unicode, String encoding, String errors) { - return Py.newString(codecs.encode(unicode, encoding, errors)); + public static PyString encode(PyUnicode unicode, PyString encoding, PyString errors) { + return Py.newString(codecs.encode(unicode, _castString(encoding), _castString(errors))); } /* --- Some codec support methods -------------------------------------------- */ -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Mon Dec 8 04:01:54 2014 From: jython-checkins at python.org (jim.baker) Date: Mon, 08 Dec 2014 03:01:54 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Support_sys=2Egetprofile=28?= =?utf-8?q?=29_and_sys=2Egettrace=28=29?= Message-ID: <20141208030142.27505.10067@psf.io> https://hg.python.org/jython/rev/d281114823af changeset: 7435:d281114823af user: Jim Baker date: Sun Dec 07 20:01:36 2014 -0700 summary: Support sys.getprofile() and sys.gettrace() These functions were added in Python 2.6 files: Lib/test/test_profilehooks.py | 2 - src/org/python/core/PySystemState.java | 18 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_profilehooks.py b/Lib/test/test_profilehooks.py --- a/Lib/test/test_profilehooks.py +++ b/Lib/test/test_profilehooks.py @@ -17,11 +17,9 @@ def tearDown(self): sys.setprofile(None) - @unittest.skip("FIXME: broken") def test_empty(self): assert sys.getprofile() == None - @unittest.skip("FIXME: broken") def test_setget(self): def fn(*args): pass diff --git a/src/org/python/core/PySystemState.java b/src/org/python/core/PySystemState.java --- a/src/org/python/core/PySystemState.java +++ b/src/org/python/core/PySystemState.java @@ -428,6 +428,15 @@ this.recursionlimit = recursionlimit; } + public PyObject gettrace() { + ThreadState ts = Py.getThreadState(); + if (ts.tracefunc == null) { + return Py.None; + } else { + return 
((PythonTraceFunction)ts.tracefunc).tracefunc; + } + } + public void settrace(PyObject tracefunc) { ThreadState ts = Py.getThreadState(); if (tracefunc == Py.None) { @@ -437,6 +446,15 @@ } } + public PyObject getprofile() { + ThreadState ts = Py.getThreadState(); + if (ts.profilefunc == null) { + return Py.None; + } else { + return ((PythonTraceFunction)ts.profilefunc).tracefunc; + } + } + public void setprofile(PyObject profilefunc) { ThreadState ts = Py.getThreadState(); if (profilefunc == Py.None) { -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 01:08:38 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 00:08:38 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Run_all_itertools_tests=2C_?= =?utf-8?q?removing_Jython-specific_skips?= Message-ID: <20141209000831.74609.58395@psf.io> https://hg.python.org/jython/rev/5ee9b24f3d9d changeset: 7436:5ee9b24f3d9d user: Jim Baker date: Mon Dec 08 17:08:17 2014 -0700 summary: Run all itertools tests, removing Jython-specific skips Updated count, repeat, and tee in itertools to support various corner cases seen in usage. 
files: Lib/test/test_itertools.py | 112 ++++----- src/org/python/core/__builtin__.java | 2 +- src/org/python/modules/itertools/PyTeeIterator.java | 2 +- src/org/python/modules/itertools/count.java | 92 ++++++- src/org/python/modules/itertools/repeat.java | 5 + 5 files changed, 132 insertions(+), 81 deletions(-) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1,5 +1,6 @@ import unittest from test import test_support +from test.test_weakref import extra_collect from itertools import * from weakref import proxy from decimal import Decimal @@ -336,11 +337,8 @@ self.assertEqual(take(2, zip('abc',count(-3))), [('a', -3), ('b', -2)]) self.assertRaises(TypeError, count, 2, 3, 4) self.assertRaises(TypeError, count, 'a') - - #FIXME: not working in Jython - #self.assertEqual(list(islice(count(maxsize-5), 10)), range(maxsize-5, maxsize+5)) - #self.assertEqual(list(islice(count(-maxsize-5), 10)), range(-maxsize-5, -maxsize+5)) - + self.assertEqual(list(islice(count(maxsize-5), 10)), range(maxsize-5, maxsize+5)) + self.assertEqual(list(islice(count(-maxsize-5), 10)), range(-maxsize-5, -maxsize+5)) c = count(3) self.assertEqual(repr(c), 'count(3)') c.next() @@ -348,27 +346,20 @@ c = count(-9) self.assertEqual(repr(c), 'count(-9)') c.next() + self.assertEqual(repr(count(10.25)), 'count(10.25)') + self.assertEqual(c.next(), -8) + for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5): + # Test repr (ignoring the L in longs) + r1 = repr(count(i)).replace('L', '') + r2 = 'count(%r)'.__mod__(i).replace('L', '') + self.assertEqual(r1, r2) - #FIXME: not working in Jython - #self.assertEqual(repr(count(10.25)), 'count(10.25)') - self.assertEqual(c.next(), -8) - - #FIXME: not working in Jython - if not test_support.is_jython: - for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5): - # Test repr (ignoring the L in longs) - r1 = 
repr(count(i)).replace('L', '') - r2 = 'count(%r)'.__mod__(i).replace('L', '') - self.assertEqual(r1, r2) - - #FIXME: not working in Jython # check copy, deepcopy, pickle - if not test_support.is_jython: - for value in -3, 3, sys.maxint-5, sys.maxint+5: - c = count(value) - self.assertEqual(next(copy.copy(c)), value) - self.assertEqual(next(copy.deepcopy(c)), value) - self.assertEqual(next(pickle.loads(pickle.dumps(c))), value) + for value in -3, 3, sys.maxint-5, sys.maxint+5: + c = count(value) + self.assertEqual(next(copy.copy(c)), value) + self.assertEqual(next(copy.deepcopy(c)), value) + self.assertEqual(next(pickle.loads(pickle.dumps(c))), value) def test_count_with_stride(self): self.assertEqual(zip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)]) @@ -378,17 +369,14 @@ [('a', 0), ('b', -1), ('c', -2)]) self.assertEqual(zip('abc',count(2,0)), [('a', 2), ('b', 2), ('c', 2)]) self.assertEqual(zip('abc',count(2,1)), [('a', 2), ('b', 3), ('c', 4)]) - - #FIXME: not working in Jython - #self.assertEqual(take(20, count(maxsize-15, 3)), take(20, range(maxsize-15, maxsize+100, 3))) - #self.assertEqual(take(20, count(-maxsize-15, 3)), take(20, range(-maxsize-15,-maxsize+100, 3))) - #self.assertEqual(take(3, count(2, 3.25-4j)), [2, 5.25-4j, 8.5-8j]) - #self.assertEqual(take(3, count(Decimal('1.1'), Decimal('.1'))), - # [Decimal('1.1'), Decimal('1.2'), Decimal('1.3')]) - #self.assertEqual(take(3, count(Fraction(2,3), Fraction(1,7))), - # [Fraction(2,3), Fraction(17,21), Fraction(20,21)]) - #self.assertEqual(repr(take(3, count(10, 2.5))), repr([10, 12.5, 15.0])) - + self.assertEqual(take(20, count(maxsize-15, 3)), take(20, range(maxsize-15, maxsize+100, 3))) + self.assertEqual(take(20, count(-maxsize-15, 3)), take(20, range(-maxsize-15,-maxsize+100, 3))) + self.assertEqual(take(3, count(2, 3.25-4j)), [2, 5.25-4j, 8.5-8j]) + self.assertEqual(take(3, count(Decimal('1.1'), Decimal('.1'))), + [Decimal('1.1'), Decimal('1.2'), Decimal('1.3')]) + self.assertEqual(take(3, 
count(Fraction(2,3), Fraction(1,7))), + [Fraction(2,3), Fraction(17,21), Fraction(20,21)]) + self.assertEqual(repr(take(3, count(10, 2.5))), repr([10, 12.5, 15.0])) c = count(3, 5) self.assertEqual(repr(c), 'count(3, 5)') c.next() @@ -402,23 +390,18 @@ c.next() self.assertEqual(repr(c), 'count(-12, -3)') self.assertEqual(repr(c), 'count(-12, -3)') - - #FIXME: not working in Jython - #self.assertEqual(repr(count(10.5, 1.25)), 'count(10.5, 1.25)') - #self.assertEqual(repr(count(10.5, 1)), 'count(10.5)') # suppress step=1 when it's an int - #self.assertEqual(repr(count(10.5, 1.00)), 'count(10.5, 1.0)') # do show float values lilke 1.0 - - #FIXME: not working in Jython - if not test_support.is_jython: - for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5): - for j in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 1, 10, sys.maxint-5, sys.maxint+5): - # Test repr (ignoring the L in longs) - r1 = repr(count(i, j)).replace('L', '') - if j == 1: - r2 = ('count(%r)' % i).replace('L', '') - else: - r2 = ('count(%r, %r)' % (i, j)).replace('L', '') - self.assertEqual(r1, r2) + self.assertEqual(repr(count(10.5, 1.25)), 'count(10.5, 1.25)') + self.assertEqual(repr(count(10.5, 1)), 'count(10.5)') # suppress step=1 when it's an int + self.assertEqual(repr(count(10.5, 1.00)), 'count(10.5, 1.0)') # do show float values lilke 1.0 + for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5): + for j in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 1, 10, sys.maxint-5, sys.maxint+5): + # Test repr (ignoring the L in longs) + r1 = repr(count(i, j)).replace('L', '') + if j == 1: + r2 = ('count(%r)' % i).replace('L', '') + else: + r2 = ('count(%r, %r)' % (i, j)).replace('L', '') + self.assertEqual(r1, r2) def test_cycle(self): self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) @@ -918,14 +901,25 @@ self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc')) # test that tee objects are weak referencable - a, b = 
tee(xrange(10)) - p = proxy(a) - self.assertEqual(getattr(p, '__class__'), type(b)) - del a + def delocalize(): + # local variables in Jython cannot be deleted to see + # objects go out of scope immediately. Except for tests + # like this however this is not going to be observed! + a, b = tee(xrange(10)) + return dict(a=a, b=b) - #FIXME: not working in Jython - if not test_support.is_jython: - self.assertRaises(ReferenceError, getattr, p, '__class__') + d = delocalize() + p = proxy(d['a']) + self.assertEqual(getattr(p, '__class__'), type(d['b'])) + del d['a'] + extra_collect() # necessary for Jython to ensure ref queue is cleared out + self.assertRaises(ReferenceError, getattr, p, '__class__') + + # Issue 13454: Crash when deleting backward iterator from tee() + def test_tee_del_backward(self): + forward, backward = tee(repeat(None, 20000000)) + any(forward) # exhaust the iterator + del backward def test_StopIteration(self): self.assertRaises(StopIteration, izip().next) diff --git a/src/org/python/core/__builtin__.java b/src/org/python/core/__builtin__.java --- a/src/org/python/core/__builtin__.java +++ b/src/org/python/core/__builtin__.java @@ -972,7 +972,7 @@ } try { // See PyXRange.getLenOfRange for the primitive version - PyObject diff = hi.__sub__(lo).__sub__(Py.One); + PyObject diff = hi._sub(lo)._sub(Py.One); PyObject n = diff.__floordiv__(step).__add__(Py.One); return n.asInt(); } catch (PyException pye) { diff --git a/src/org/python/modules/itertools/PyTeeIterator.java b/src/org/python/modules/itertools/PyTeeIterator.java --- a/src/org/python/modules/itertools/PyTeeIterator.java +++ b/src/org/python/modules/itertools/PyTeeIterator.java @@ -94,7 +94,7 @@ throw Py.ValueError("n must be >= 0"); } - PyObject[] tees = new PyTeeIterator[n]; + PyObject[] tees = new PyObject[n]; if (n == 0) { return tees; diff --git a/src/org/python/modules/itertools/count.java b/src/org/python/modules/itertools/count.java --- a/src/org/python/modules/itertools/count.java +++ 
b/src/org/python/modules/itertools/count.java @@ -3,12 +3,14 @@ import org.python.core.ArgParser; import org.python.core.Py; +import org.python.core.PyException; import org.python.core.PyInteger; import org.python.core.PyIterator; import org.python.core.PyObject; import org.python.core.PyString; import org.python.core.PyTuple; import org.python.core.PyType; +import org.python.core.__builtin__; import org.python.expose.ExposedNew; import org.python.expose.ExposedMethod; import org.python.expose.ExposedType; @@ -18,8 +20,16 @@ public static final PyType TYPE = PyType.fromClass(count.class); private PyIterator iter; - private int counter; - private int stepper; + private PyObject counter; + private PyObject stepper; + + private static PyObject NumberClass; + private static synchronized PyObject getNumberClass() { + if (NumberClass == null) { + NumberClass = __builtin__.__import__("numbers").__getattr__("Number"); + } + return NumberClass; + } public static final String count_doc = "count(start=0, step=1) --> count object\n\n" + @@ -37,62 +47,104 @@ } /** - * Creates an iterator that returns consecutive integers starting at 0. + * Creates an iterator that returns consecutive numbers starting at 0. */ public count() { super(); - count___init__(0, 1); + count___init__(Py.Zero, Py.One); } /** - * Creates an iterator that returns consecutive integers starting at start. + * Creates an iterator that returns consecutive numbers starting at start. */ - public count(final int start) { + public count(final PyObject start) { super(); - count___init__(start, 1); + count___init__(start, Py.One); } /** - * Creates an iterator that returns consecutive integers starting at start with step step. + * Creates an iterator that returns consecutive numbers starting at start with step step. 
*/ - public count(final int start, final int step) { + public count(final PyObject start, final PyObject step) { super(); count___init__(start, step); } + // TODO: move into Py, although NumberClass import time resolution becomes + // TODO: a bit trickier + private static PyObject getNumber(PyObject obj) { + if (Py.isInstance(obj, getNumberClass())) { + return obj; + } + try { + PyObject intObj = obj.__int__(); + if (Py.isInstance(obj, getNumberClass())) { + return intObj; + } + throw Py.TypeError("a number is required"); + } catch (PyException exc) { + if (exc.match(Py.ValueError)) { + throw Py.TypeError("a number is required"); + } + throw exc; + } + } + @ExposedNew @ExposedMethod final void count___init__(final PyObject[] args, String[] kwds) { ArgParser ap = new ArgParser("count", args, kwds, new String[] {"start", "step"}, 0); - - int start = ap.getInt(0, 0); - int step = ap.getInt(1, 1); + PyObject start = getNumber(ap.getPyObject(0, Py.Zero)); + PyObject step = getNumber(ap.getPyObject(1, Py.One)); count___init__(start, step); } - private void count___init__(final int start, final int step) { + private void count___init__(final PyObject start, final PyObject step) { counter = start; stepper = step; iter = new PyIterator() { public PyObject __iternext__() { - int result = counter; - counter += stepper; - return new PyInteger(result); + PyObject result = counter; + counter = counter._add(stepper); + return result; } }; } @ExposedMethod + public PyObject count___copy__() { + return new count(counter, stepper); + } + + @ExposedMethod + final PyObject count___reduce_ex__(PyObject protocol) { + return __reduce_ex__(protocol); + } + + @ExposedMethod + final PyObject count___reduce__() { + return __reduce_ex__(Py.Zero); + } + + + public PyObject __reduce_ex__(PyObject protocol) { + if (stepper == Py.One) { + return new PyTuple(getType(), new PyTuple(counter)); + } else { + return new PyTuple(getType(), new PyTuple(counter, stepper)); + } + } + + @ExposedMethod 
public PyString __repr__() { - if (stepper == 1) { - return (PyString)(Py.newString("count(%d)").__mod__(Py.newInteger(counter))); + if (stepper instanceof PyInteger && stepper._cmp(Py.One) == 0) { + return Py.newString(String.format("count(%s)", counter)); } else { - return (PyString)(Py.newString("count(%d, %d)").__mod__(new PyTuple( - Py.newInteger(counter), Py.newInteger(stepper)))); + return Py.newString(String.format("count(%s, %s)", counter, stepper)); } } diff --git a/src/org/python/modules/itertools/repeat.java b/src/org/python/modules/itertools/repeat.java --- a/src/org/python/modules/itertools/repeat.java +++ b/src/org/python/modules/itertools/repeat.java @@ -97,6 +97,11 @@ } @ExposedMethod + final PyObject __copy__() { + return new repeat(object, counter); + } + + @ExposedMethod public int __len__() { if (counter < 0) { throw Py.TypeError("object of type 'itertools.repeat' has no len()"); -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 02:45:40 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 01:45:40 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fix_slice_indices_computati?= =?utf-8?q?on_and_support_slice_pickling?= Message-ID: <20141209014537.50876.91077@psf.io> https://hg.python.org/jython/rev/4191d2014781 changeset: 7437:4191d2014781 user: Jim Baker date: Mon Dec 08 18:45:17 2014 -0700 summary: Fix slice indices computation and support slice pickling files: Lib/test/test_slice.py | 137 ------------------- src/org/python/core/PySlice.java | 86 ++++++----- 2 files changed, 49 insertions(+), 174 deletions(-) diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py deleted file mode 100644 --- a/Lib/test/test_slice.py +++ /dev/null @@ -1,137 +0,0 @@ -# tests for slice objects; in particular the indices method. 
- -import unittest -from test import test_support -from cPickle import loads, dumps - -import sys - -class SliceTest(unittest.TestCase): - - def test_constructor(self): - self.assertRaises(TypeError, slice) - self.assertRaises(TypeError, slice, 1, 2, 3, 4) - - def test_repr(self): - self.assertEqual(repr(slice(1, 2, 3)), "slice(1, 2, 3)") - - def test_hash(self): - # Verify clearing of SF bug #800796 - self.assertRaises(TypeError, hash, slice(5)) - self.assertRaises(TypeError, slice(5).__hash__) - - def test_cmp(self): - s1 = slice(1, 2, 3) - s2 = slice(1, 2, 3) - s3 = slice(1, 2, 4) - self.assertEqual(s1, s2) - self.assertNotEqual(s1, s3) - - class Exc(Exception): - pass - - class BadCmp(object): - def __eq__(self, other): - raise Exc - __hash__ = None # Silence Py3k warning - - s1 = slice(BadCmp()) - s2 = slice(BadCmp()) - self.assertRaises(Exc, cmp, s1, s2) - self.assertEqual(s1, s1) - - s1 = slice(1, BadCmp()) - s2 = slice(1, BadCmp()) - self.assertEqual(s1, s1) - self.assertRaises(Exc, cmp, s1, s2) - - s1 = slice(1, 2, BadCmp()) - s2 = slice(1, 2, BadCmp()) - self.assertEqual(s1, s1) - self.assertRaises(Exc, cmp, s1, s2) - - def test_members(self): - s = slice(1) - self.assertEqual(s.start, None) - self.assertEqual(s.stop, 1) - self.assertEqual(s.step, None) - - s = slice(1, 2) - self.assertEqual(s.start, 1) - self.assertEqual(s.stop, 2) - self.assertEqual(s.step, None) - - s = slice(1, 2, 3) - self.assertEqual(s.start, 1) - self.assertEqual(s.stop, 2) - self.assertEqual(s.step, 3) - - class AnyClass: - pass - - obj = AnyClass() - s = slice(obj) - self.assertTrue(s.stop is obj) - - def test_indices(self): - self.assertEqual(slice(None ).indices(10), (0, 10, 1)) - self.assertEqual(slice(None, None, 2).indices(10), (0, 10, 2)) - self.assertEqual(slice(1, None, 2).indices(10), (1, 10, 2)) - self.assertEqual(slice(None, None, -1).indices(10), (9, -1, -1)) - self.assertEqual(slice(None, None, -2).indices(10), (9, -1, -2)) - self.assertEqual(slice(3, None, 
-2).indices(10), (3, -1, -2)) - # issue 3004 tests - self.assertEqual(slice(None, -9).indices(10), (0, 1, 1)) - #FIXME: next two not correct on Jython - #self.assertEqual(slice(None, -10).indices(10), (0, 0, 1)) - #self.assertEqual(slice(None, -11).indices(10), (0, 0, 1)) - self.assertEqual(slice(None, -10, -1).indices(10), (9, 0, -1)) - self.assertEqual(slice(None, -11, -1).indices(10), (9, -1, -1)) - self.assertEqual(slice(None, -12, -1).indices(10), (9, -1, -1)) - self.assertEqual(slice(None, 9).indices(10), (0, 9, 1)) - self.assertEqual(slice(None, 10).indices(10), (0, 10, 1)) - self.assertEqual(slice(None, 11).indices(10), (0, 10, 1)) - self.assertEqual(slice(None, 8, -1).indices(10), (9, 8, -1)) - self.assertEqual(slice(None, 9, -1).indices(10), (9, 9, -1)) - #FIXME: next not correct on Jython - #self.assertEqual(slice(None, 10, -1).indices(10), (9, 9, -1)) - - self.assertEqual( - slice(-100, 100 ).indices(10), - slice(None).indices(10) - ) - self.assertEqual( - slice(100, -100, -1).indices(10), - slice(None, None, -1).indices(10) - ) - self.assertEqual(slice(-100L, 100L, 2L).indices(10), (0, 10, 2)) - - self.assertEqual(range(10)[::sys.maxint - 1], [0]) - - self.assertRaises(OverflowError, slice(None).indices, 1L<<100) - - def test_setslice_without_getslice(self): - tmp = [] - class X(object): - def __setslice__(self, i, j, k): - tmp.append((i, j, k)) - - x = X() - with test_support.check_py3k_warnings(): - x[1:2] = 42 - self.assertEqual(tmp, [(1, 2, 42)]) - - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") - def test_pickle(self): - s = slice(10, 20, 3) - for protocol in (0,1,2): - t = loads(dumps(s, protocol)) - self.assertEqual(s, t) - self.assertEqual(s.indices(15), t.indices(15)) - self.assertNotEqual(id(s), id(t)) - -def test_main(): - test_support.run_unittest(SliceTest) - -if __name__ == "__main__": - test_main() diff --git a/src/org/python/core/PySlice.java b/src/org/python/core/PySlice.java --- 
a/src/org/python/core/PySlice.java +++ b/src/org/python/core/PySlice.java @@ -113,62 +113,64 @@ * @return an array with the start at index 0, stop at index 1, step at index 2 and * slicelength at index 3 */ - public int[] indicesEx(int len) { - int start; - int stop; - int step; - int slicelength; + public int[] indicesEx(int length) { + /* The corresponding C code (PySlice_GetIndicesEx) states: + * "this is harder to get right than you might think" + * As a consequence, I have chosen to copy the code and translate to Java. + * Note *rstart, etc., become result_start - the usual changes we need + * when going from pointers to corresponding Java. + */ - if (getStep() == Py.None) { - step = 1; + int defstart, defstop; + int result_start, result_stop, result_step, result_slicelength; + + if (step == Py.None) { + result_step = 1; } else { - step = calculateSliceIndex(getStep()); - if (step == 0) { + result_step = calculateSliceIndex(step); + if (result_step == 0) { throw Py.ValueError("slice step cannot be zero"); } } - if (getStart() == Py.None) { - start = step < 0 ? len - 1 : 0; + defstart = result_step < 0 ? length - 1 : 0; + defstop = result_step < 0 ? -1 : length; + + if (start == Py.None) { + result_start = defstart; } else { - start = calculateSliceIndex(getStart()); - if (start < 0) { - start += len; - } - if (start < 0) { - start = step < 0 ? -1 : 0; - } - if (start >= len) { - start = step < 0 ? len - 1 : len; + result_start = calculateSliceIndex(start); + if (result_start < 0) result_start += length; + if (result_start < 0) result_start = (result_step < 0) ? -1 : 0; + if (result_start >= length) { + result_start = (result_step < 0) ? length - 1 : length; } } - if (getStop() == Py.None) { - stop = step < 0 ? 
-1 : len; + if (stop == Py.None) { + result_stop = defstop; } else { - stop = calculateSliceIndex(getStop()); - if (stop < 0) { - stop += len; - } - if (stop < 0) { - stop = -1; - } - if (stop > len) { - stop = len; + result_stop = calculateSliceIndex(stop); + if (result_stop < 0) result_stop += length; + if (result_stop < 0) result_stop = (result_step < 0) ? -1 : 0; + if (result_stop >= length) { + result_stop = (result_step < 0) ? length - 1 : length; } } - if ((step < 0 && stop >= start) || (step > 0 && start >= stop)) { - slicelength = 0; - } else if (step < 0) { - slicelength = (stop - start + 1) / (step) + 1; + if ((result_step < 0 && result_stop >= result_start) + || (result_step > 0 && result_start >= result_stop)) { + result_slicelength = 0; + } else if (result_step < 0) { + result_slicelength = (result_stop - result_start + 1) / (result_step) + 1; } else { - slicelength = (stop - start - 1) / (step) + 1; + result_slicelength = (result_stop - result_start - 1) / (result_step) + 1; } - return new int[] {start, stop, step, slicelength}; + return new int[]{result_start, result_stop, result_step, result_slicelength}; } + /** * Calculate indices for the deprecated __get/set/delslice__ methods. 
* @@ -230,4 +232,14 @@ public final PyObject getStep() { return step; } + + @ExposedMethod + final PyObject slice___reduce__() { + return new PyTuple(getType(), new PyTuple(start, stop, step)); + } + + @ExposedMethod(defaults = "Py.None") + final PyObject slice___reduce_ex__(PyObject protocol) { + return new PyTuple(getType(), new PyTuple(start, stop, step)); + } } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 04:41:23 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 03:41:23 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fix_flaky_tests_in_test=5Fw?= =?utf-8?q?eakref_and_test=5Fweakset?= Message-ID: <20141209034123.50874.2934@psf.io> https://hg.python.org/jython/rev/8833d46c7dd6 changeset: 7438:8833d46c7dd6 user: Jim Baker date: Mon Dec 08 20:41:17 2014 -0700 summary: Fix flaky tests in test_weakref and test_weakset Weak reference callbacks are called by a separate reaper thread, therefore tests have to be sensitive to possible races. Although arguably making this more robust by appropriate interleaving of gc.collect() and time.sleep() is not good enough, these tests are as good as we can make without making them more JVM dependent. 
files: Lib/test/test_weakref.py | 2 +- Lib/test/test_weakset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -153,7 +153,7 @@ o2 = C() ref3 = weakref.proxy(o2) del o2 - gc.collect() + extra_collect() self.assertRaises(weakref.ReferenceError, bool, ref3) self.assertTrue(self.cbcalled == 2) diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -391,7 +391,7 @@ # We have removed either the first consumed items, or another one self.assertIn(len(list(it)), [len(items), len(items) - 1]) del it - gc.collect() + extra_collect() # The removal has been committed self.assertEqual(len(s), len(items)) -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 04:53:25 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 03:53:25 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Run_as_many_of_the_2=2E7_te?= =?utf-8?q?sts_in_test=5Fsort_as_possible?= Message-ID: <20141209035324.657.95879@psf.io> https://hg.python.org/jython/rev/139bf89578c0 changeset: 7439:139bf89578c0 user: Jim Baker date: Mon Dec 08 20:53:20 2014 -0700 summary: Run as many of the 2.7 tests in test_sort as possible Jython uses the underlying TimSort implementation in java.util.Arrays.sort (on objects, not primitives), so this means there are a few Python tests around illegal usages such as random comparison functions that do not behave the same. Given that this is for verifying things like not segfaulting, this is not something we need to duplicate as well. 
files: Lib/test/test_sort.py | 17 +++++++++++++---- 1 files changed, 13 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_sort.py b/Lib/test/test_sort.py --- a/Lib/test/test_sort.py +++ b/Lib/test/test_sort.py @@ -2,6 +2,10 @@ import random import sys import unittest +try: + import java +except ImportError: + pass verbose = test_support.verbose nerrors = 0 @@ -39,8 +43,6 @@ return class TestBase(unittest.TestCase): - @unittest.skipIf(test_support.is_jython, - "FIXME: find the part that is too much for Jython.") def testStressfully(self): # Try a variety of sizes at and around powers of 2, and at powers of 10. sizes = [0] @@ -102,8 +104,15 @@ print " Checking against an insane comparison function." print " If the implementation isn't careful, this may segfault." s = x[:] - s.sort(lambda a, b: int(random.random() * 3) - 1) - check("an insane function left some permutation", x, s) + + if test_support.is_jython: + try: + s.sort(lambda a, b: int(random.random() * 3) - 1) + except java.lang.IllegalArgumentException: + pass + else: + s.sort(lambda a, b: int(random.random() * 3) - 1) + check("an insane function left some permutation", x, s) x = [Complains(i) for i in x] s = x[:] -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 05:00:49 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 04:00:49 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Java_does_not_allow_access_?= =?utf-8?q?to_object_sizes?= Message-ID: <20141209040048.54512.14681@psf.io> https://hg.python.org/jython/rev/ba029506fbb2 changeset: 7440:ba029506fbb2 user: Jim Baker date: Mon Dec 08 21:00:44 2014 -0700 summary: Java does not allow access to object sizes __basicsize__ and __itemsize__ are not available in Jython; see http://bugs.jython.org/issue1017 files: Lib/test/test_types.py | 2 +- 1 files changed, 1 insertions(+), 1 deletions(-) diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py --- a/Lib/test/test_types.py +++ 
b/Lib/test/test_types.py @@ -741,7 +741,7 @@ for code in 'xXobns': self.assertRaises(ValueError, format, 0, ',' + code) - @unittest.skipIf(is_jython, "FIXME: not working") + @unittest.skipIf(is_jython, "Java does not allow access to object sizes") def test_internal_sizes(self): self.assertGreater(object.__basicsize__, 0) self.assertGreater(tuple.__itemsize__, 0) -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 05:15:32 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 04:15:32 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Triage_email_tests_using_CJ?= =?utf-8?q?K_encodings?= Message-ID: <20141209041528.125263.38807@psf.io> https://hg.python.org/jython/rev/56b86504c0a5 changeset: 7441:56b86504c0a5 user: Jim Baker date: Mon Dec 08 21:15:23 2014 -0700 summary: Triage email tests using CJK encodings Jython uses Java-based codecs for CJK/multibytecodec support, but this is not as robust for trailing bytes as seen in test_email and test_email_renamed. This should be looked at again post 2.7.0. 
files: Lib/email/test/test_email.py | 3561 ++++++++++++++ Lib/email/test/test_email_renamed.py | 3332 +++++++++++++ 2 files changed, 6893 insertions(+), 0 deletions(-) diff --git a/Lib/email/test/test_email.py b/Lib/email/test/test_email.py new file mode 100644 --- /dev/null +++ b/Lib/email/test/test_email.py @@ -0,0 +1,3561 @@ +# Copyright (C) 2001-2010 Python Software Foundation +# Contact: email-sig at python.org +# email package unit tests + +import os +import sys +import time +import base64 +import difflib +import unittest +import warnings +import textwrap +from cStringIO import StringIO + +import email + +from email.Charset import Charset +from email.Header import Header, decode_header, make_header +from email.Parser import Parser, HeaderParser +from email.Generator import Generator, DecodedGenerator +from email.Message import Message +from email.MIMEAudio import MIMEAudio +from email.MIMEText import MIMEText +from email.MIMEImage import MIMEImage +from email.MIMEBase import MIMEBase +from email.MIMEMessage import MIMEMessage +from email.MIMEMultipart import MIMEMultipart +from email import Utils +from email import Errors +from email import Encoders +from email import Iterators +from email import base64MIME +from email import quopriMIME + +from test.test_support import findfile, run_unittest +from email.test import __file__ as landmark +from test.test_support import is_jython + +NL = '\n' +EMPTYSTRING = '' +SPACE = ' ' + + + +def openfile(filename, mode='r'): + path = os.path.join(os.path.dirname(landmark), 'data', filename) + return open(path, mode) + + + +# Base test class +class TestEmailBase(unittest.TestCase): + def ndiffAssertEqual(self, first, second): + """Like assertEqual except use ndiff for readable output.""" + if first != second: + sfirst = str(first) + ssecond = str(second) + diff = difflib.ndiff(sfirst.splitlines(), ssecond.splitlines()) + fp = StringIO() + print >> fp, NL, NL.join(diff) + raise self.failureException, fp.getvalue() + + def 
_msgobj(self, filename): + fp = openfile(findfile(filename)) + try: + msg = email.message_from_file(fp) + finally: + fp.close() + return msg + + + +# Test various aspects of the Message class's API +class TestMessageAPI(TestEmailBase): + def test_get_all(self): + eq = self.assertEqual + msg = self._msgobj('msg_20.txt') + eq(msg.get_all('cc'), ['ccc at zzz.org', 'ddd at zzz.org', 'eee at zzz.org']) + eq(msg.get_all('xx', 'n/a'), 'n/a') + + def test_getset_charset(self): + eq = self.assertEqual + msg = Message() + eq(msg.get_charset(), None) + charset = Charset('iso-8859-1') + msg.set_charset(charset) + eq(msg['mime-version'], '1.0') + eq(msg.get_content_type(), 'text/plain') + eq(msg['content-type'], 'text/plain; charset="iso-8859-1"') + eq(msg.get_param('charset'), 'iso-8859-1') + eq(msg['content-transfer-encoding'], 'quoted-printable') + eq(msg.get_charset().input_charset, 'iso-8859-1') + # Remove the charset + msg.set_charset(None) + eq(msg.get_charset(), None) + eq(msg['content-type'], 'text/plain') + # Try adding a charset when there's already MIME headers present + msg = Message() + msg['MIME-Version'] = '2.0' + msg['Content-Type'] = 'text/x-weird' + msg['Content-Transfer-Encoding'] = 'quinted-puntable' + msg.set_charset(charset) + eq(msg['mime-version'], '2.0') + eq(msg['content-type'], 'text/x-weird; charset="iso-8859-1"') + eq(msg['content-transfer-encoding'], 'quinted-puntable') + + def test_set_charset_from_string(self): + eq = self.assertEqual + msg = Message() + msg.set_charset('us-ascii') + eq(msg.get_charset().input_charset, 'us-ascii') + eq(msg['content-type'], 'text/plain; charset="us-ascii"') + + def test_set_payload_with_charset(self): + msg = Message() + charset = Charset('iso-8859-1') + msg.set_payload('This is a string payload', charset) + self.assertEqual(msg.get_charset().input_charset, 'iso-8859-1') + + def test_get_charsets(self): + eq = self.assertEqual + + msg = self._msgobj('msg_08.txt') + charsets = msg.get_charsets() + eq(charsets, 
[None, 'us-ascii', 'iso-8859-1', 'iso-8859-2', 'koi8-r']) + + msg = self._msgobj('msg_09.txt') + charsets = msg.get_charsets('dingbat') + eq(charsets, ['dingbat', 'us-ascii', 'iso-8859-1', 'dingbat', + 'koi8-r']) + + msg = self._msgobj('msg_12.txt') + charsets = msg.get_charsets() + eq(charsets, [None, 'us-ascii', 'iso-8859-1', None, 'iso-8859-2', + 'iso-8859-3', 'us-ascii', 'koi8-r']) + + def test_get_filename(self): + eq = self.assertEqual + + msg = self._msgobj('msg_04.txt') + filenames = [p.get_filename() for p in msg.get_payload()] + eq(filenames, ['msg.txt', 'msg.txt']) + + msg = self._msgobj('msg_07.txt') + subpart = msg.get_payload(1) + eq(subpart.get_filename(), 'dingusfish.gif') + + def test_get_filename_with_name_parameter(self): + eq = self.assertEqual + + msg = self._msgobj('msg_44.txt') + filenames = [p.get_filename() for p in msg.get_payload()] + eq(filenames, ['msg.txt', 'msg.txt']) + + def test_get_boundary(self): + eq = self.assertEqual + msg = self._msgobj('msg_07.txt') + # No quotes! + eq(msg.get_boundary(), 'BOUNDARY') + + def test_set_boundary(self): + eq = self.assertEqual + # This one has no existing boundary parameter, but the Content-Type: + # header appears fifth. + msg = self._msgobj('msg_01.txt') + msg.set_boundary('BOUNDARY') + header, value = msg.items()[4] + eq(header.lower(), 'content-type') + eq(value, 'text/plain; charset="us-ascii"; boundary="BOUNDARY"') + # This one has a Content-Type: header, with a boundary, stuck in the + # middle of its headers. Make sure the order is preserved; it should + # be fifth. + msg = self._msgobj('msg_04.txt') + msg.set_boundary('BOUNDARY') + header, value = msg.items()[4] + eq(header.lower(), 'content-type') + eq(value, 'multipart/mixed; boundary="BOUNDARY"') + # And this one has no Content-Type: header at all. 
+ msg = self._msgobj('msg_03.txt') + self.assertRaises(Errors.HeaderParseError, + msg.set_boundary, 'BOUNDARY') + + def test_make_boundary(self): + msg = MIMEMultipart('form-data') + # Note that when the boundary gets created is an implementation + # detail and might change. + self.assertEqual(msg.items()[0][1], 'multipart/form-data') + # Trigger creation of boundary + msg.as_string() + self.assertEqual(msg.items()[0][1][:33], + 'multipart/form-data; boundary="==') + # XXX: there ought to be tests of the uniqueness of the boundary, too. + + def test_message_rfc822_only(self): + # Issue 7970: message/rfc822 not in multipart parsed by + # HeaderParser caused an exception when flattened. + fp = openfile(findfile('msg_46.txt')) + msgdata = fp.read() + parser = email.Parser.HeaderParser() + msg = parser.parsestr(msgdata) + out = StringIO() + gen = email.Generator.Generator(out, True, 0) + gen.flatten(msg, False) + self.assertEqual(out.getvalue(), msgdata) + + def test_get_decoded_payload(self): + eq = self.assertEqual + msg = self._msgobj('msg_10.txt') + # The outer message is a multipart + eq(msg.get_payload(decode=True), None) + # Subpart 1 is 7bit encoded + eq(msg.get_payload(0).get_payload(decode=True), + 'This is a 7bit encoded message.\n') + # Subpart 2 is quopri + eq(msg.get_payload(1).get_payload(decode=True), + '\xa1This is a Quoted Printable encoded message!\n') + # Subpart 3 is base64 + eq(msg.get_payload(2).get_payload(decode=True), + 'This is a Base64 encoded message.') + # Subpart 4 is base64 with a trailing newline, which + # used to be stripped (issue 7143). + eq(msg.get_payload(3).get_payload(decode=True), + 'This is a Base64 encoded message.\n') + # Subpart 5 has no Content-Transfer-Encoding: header. 
+ eq(msg.get_payload(4).get_payload(decode=True), + 'This has no Content-Transfer-Encoding: header.\n') + + def test_get_decoded_uu_payload(self): + eq = self.assertEqual + msg = Message() + msg.set_payload('begin 666 -\n+:&5L;&\\@=V]R;&0 \n \nend\n') + for cte in ('x-uuencode', 'uuencode', 'uue', 'x-uue'): + msg['content-transfer-encoding'] = cte + eq(msg.get_payload(decode=True), 'hello world') + # Now try some bogus data + msg.set_payload('foo') + eq(msg.get_payload(decode=True), 'foo') + + def test_decode_bogus_uu_payload_quietly(self): + msg = Message() + msg.set_payload('begin 664 foo.txt\n%' % i for i in range(10)]) + msg.set_payload('Test') + sfp = StringIO() + g = Generator(sfp) + g.flatten(msg) + eq(sfp.getvalue(), """\ +From: test at dom.ain +References: <0 at dom.ain> <1 at dom.ain> <2 at dom.ain> <3 at dom.ain> <4 at dom.ain> + <5 at dom.ain> <6 at dom.ain> <7 at dom.ain> <8 at dom.ain> <9 at dom.ain> + +Test""") + + def test_no_split_long_header(self): + eq = self.ndiffAssertEqual + hstr = 'References: ' + 'x' * 80 + h = Header(hstr, continuation_ws='\t') + eq(h.encode(), """\ +References: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx""") + + def test_splitting_multiple_long_lines(self): + eq = self.ndiffAssertEqual + hstr = """\ +from babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for ; Sat, 2 Feb 2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for ; Sat, 2 Feb 2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for ; Sat, 2 Feb 2002 17:00:06 -0800 (PST) +""" + h = Header(hstr, continuation_ws='\t') + eq(h.encode(), """\ +from babylon.socal-raves.org (localhost [127.0.0.1]); +\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; +\tfor ; +\tSat, 2 Feb 
2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); +\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; +\tfor ; +\tSat, 2 Feb 2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); +\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; +\tfor ; +\tSat, 2 Feb 2002 17:00:06 -0800 (PST)""") + + def test_splitting_first_line_only_is_long(self): + eq = self.ndiffAssertEqual + hstr = """\ +from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93] helo=cthulhu.gerg.ca) +\tby kronos.mems-exchange.org with esmtp (Exim 4.05) +\tid 17k4h5-00034i-00 +\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400""" + h = Header(hstr, maxlinelen=78, header_name='Received', + continuation_ws='\t') + eq(h.encode(), """\ +from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93] +\thelo=cthulhu.gerg.ca) +\tby kronos.mems-exchange.org with esmtp (Exim 4.05) +\tid 17k4h5-00034i-00 +\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400""") + + def test_long_8bit_header(self): + eq = self.ndiffAssertEqual + msg = Message() + h = Header('Britische Regierung gibt', 'iso-8859-1', + header_name='Subject') + h.append('gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte') + msg['Subject'] = h + eq(msg.as_string(), """\ +Subject: =?iso-8859-1?q?Britische_Regierung_gibt?= =?iso-8859-1?q?gr=FCnes?= + =?iso-8859-1?q?_Licht_f=FCr_Offshore-Windkraftprojekte?= + +""") + + def test_long_8bit_header_no_charset(self): + eq = self.ndiffAssertEqual + msg = Message() + msg['Reply-To'] = 'Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte ' + eq(msg.as_string(), """\ +Reply-To: Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte + +""") + + def test_long_to_header(self): + eq = self.ndiffAssertEqual + to = '"Someone Test #A" ,,"Someone Test #B" , "Someone Test #C" , "Someone Test #D" ' + msg = Message() + msg['To'] = to + eq(msg.as_string(0), 
'''\ +To: "Someone Test #A" , , + "Someone Test #B" , + "Someone Test #C" , + "Someone Test #D" + +''') + + def test_long_line_after_append(self): + eq = self.ndiffAssertEqual + s = 'This is an example of string which has almost the limit of header length.' + h = Header(s) + h.append('Add another line.') + eq(h.encode(), """\ +This is an example of string which has almost the limit of header length. + Add another line.""") + + def test_shorter_line_with_append(self): + eq = self.ndiffAssertEqual + s = 'This is a shorter line.' + h = Header(s) + h.append('Add another sentence. (Surprise?)') + eq(h.encode(), + 'This is a shorter line. Add another sentence. (Surprise?)') + + def test_long_field_name(self): + eq = self.ndiffAssertEqual + fn = 'X-Very-Very-Very-Long-Header-Name' + gs = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. " + h = Header(gs, 'iso-8859-1', header_name=fn) + # BAW: this seems broken because the first line is too long + eq(h.encode(), """\ +=?iso-8859-1?q?Die_Mieter_treten_hier_?= + =?iso-8859-1?q?ein_werden_mit_einem_Foerderband_komfortabel_den_Korridor_?= + =?iso-8859-1?q?entlang=2C_an_s=FCdl=FCndischen_Wandgem=E4lden_vorbei=2C_g?= + =?iso-8859-1?q?egen_die_rotierenden_Klingen_bef=F6rdert=2E_?=""") + + def test_long_received_header(self): + h = 'from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP; Wed, 05 Mar 2003 18:10:18 -0700' + msg = Message() + msg['Received-1'] = Header(h, continuation_ws='\t') + msg['Received-2'] = h + self.assertEqual(msg.as_string(), """\ +Received-1: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by +\throthgar.la.mastaler.com (tmda-ofmipd) with ESMTP; +\tWed, 05 Mar 2003 18:10:18 -0700 +Received-2: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by + hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP; + Wed, 05 Mar 2003 
18:10:18 -0700 + +""") + + def test_string_headerinst_eq(self): + h = '<15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> (David Bremner\'s message of "Thu, 6 Mar 2003 13:58:21 +0100")' + msg = Message() + msg['Received'] = Header(h, header_name='Received', + continuation_ws='\t') + msg['Received'] = h + self.ndiffAssertEqual(msg.as_string(), """\ +Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> +\t(David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100") +Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> + (David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100") + +""") + + def test_long_unbreakable_lines_with_continuation(self): + eq = self.ndiffAssertEqual + msg = Message() + t = """\ + iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9 + locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp""" + msg['Face-1'] = t + msg['Face-2'] = Header(t, header_name='Face-2') + eq(msg.as_string(), """\ +Face-1: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9 + locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp +Face-2: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9 + locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp + +""") + + def test_another_long_multiline_header(self): + eq = self.ndiffAssertEqual + m = '''\ +Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with Microsoft SMTPSVC(5.0.2195.4905); + Wed, 16 Oct 2002 07:41:11 -0700''' + msg = email.message_from_string(m) + eq(msg.as_string(), '''\ +Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with + Microsoft SMTPSVC(5.0.2195.4905); Wed, 16 Oct 2002 07:41:11 -0700 + +''') + + def test_long_lines_with_different_header(self): + eq = self.ndiffAssertEqual + h = """\ +List-Unsubscribe: , + """ + msg = Message() + msg['List'] = h + msg['List'] = Header(h, 
header_name='List') + eq(msg.as_string(), """\ +List: List-Unsubscribe: , + +List: List-Unsubscribe: , + + +""") + + + +# Test mangling of "From " lines in the body of a message +class TestFromMangling(unittest.TestCase): + def setUp(self): + self.msg = Message() + self.msg['From'] = 'aaa at bbb.org' + self.msg.set_payload("""\ +From the desk of A.A.A.: +Blah blah blah +""") + + def test_mangled_from(self): + s = StringIO() + g = Generator(s, mangle_from_=True) + g.flatten(self.msg) + self.assertEqual(s.getvalue(), """\ +From: aaa at bbb.org + +>From the desk of A.A.A.: +Blah blah blah +""") + + def test_dont_mangle_from(self): + s = StringIO() + g = Generator(s, mangle_from_=False) + g.flatten(self.msg) + self.assertEqual(s.getvalue(), """\ +From: aaa at bbb.org + +From the desk of A.A.A.: +Blah blah blah +""") + + def test_mangle_from_in_preamble_and_epilog(self): + s = StringIO() + g = Generator(s, mangle_from_=True) + msg = email.message_from_string(textwrap.dedent("""\ + From: foo at bar.com + Mime-Version: 1.0 + Content-Type: multipart/mixed; boundary=XXX + + From somewhere unknown + + --XXX + Content-Type: text/plain + + foo + + --XXX-- + + From somewhere unknowable + """)) + g.flatten(msg) + self.assertEqual(len([1 for x in s.getvalue().split('\n') + if x.startswith('>From ')]), 2) + + +# Test the basic MIMEAudio class +class TestMIMEAudio(unittest.TestCase): + def setUp(self): + # Make sure we pick up the audiotest.au that lives in email/test/data. + # In Python, there's an audiotest.au living in Lib/test but that isn't + # included in some binary distros that don't include the test + # package. The trailing empty string on the .join() is significant + # since findfile() will do a dirname(). 
+ datadir = os.path.join(os.path.dirname(landmark), 'data', '') + fp = open(findfile('audiotest.au', datadir), 'rb') + try: + self._audiodata = fp.read() + finally: + fp.close() + self._au = MIMEAudio(self._audiodata) + + def test_guess_minor_type(self): + self.assertEqual(self._au.get_content_type(), 'audio/basic') + + def test_encoding(self): + payload = self._au.get_payload() + self.assertEqual(base64.decodestring(payload), self._audiodata) + + def test_checkSetMinor(self): + au = MIMEAudio(self._audiodata, 'fish') + self.assertEqual(au.get_content_type(), 'audio/fish') + + def test_add_header(self): + eq = self.assertEqual + unless = self.assertTrue + self._au.add_header('Content-Disposition', 'attachment', + filename='audiotest.au') + eq(self._au['content-disposition'], + 'attachment; filename="audiotest.au"') + eq(self._au.get_params(header='content-disposition'), + [('attachment', ''), ('filename', 'audiotest.au')]) + eq(self._au.get_param('filename', header='content-disposition'), + 'audiotest.au') + missing = [] + eq(self._au.get_param('attachment', header='content-disposition'), '') + unless(self._au.get_param('foo', failobj=missing, + header='content-disposition') is missing) + # Try some missing stuff + unless(self._au.get_param('foobar', missing) is missing) + unless(self._au.get_param('attachment', missing, + header='foobar') is missing) + + + +# Test the basic MIMEImage class +class TestMIMEImage(unittest.TestCase): + def setUp(self): + fp = openfile('PyBanner048.gif') + try: + self._imgdata = fp.read() + finally: + fp.close() + self._im = MIMEImage(self._imgdata) + + def test_guess_minor_type(self): + self.assertEqual(self._im.get_content_type(), 'image/gif') + + def test_encoding(self): + payload = self._im.get_payload() + self.assertEqual(base64.decodestring(payload), self._imgdata) + + def test_checkSetMinor(self): + im = MIMEImage(self._imgdata, 'fish') + self.assertEqual(im.get_content_type(), 'image/fish') + + def test_add_header(self): + eq = 
self.assertEqual + unless = self.assertTrue + self._im.add_header('Content-Disposition', 'attachment', + filename='dingusfish.gif') + eq(self._im['content-disposition'], + 'attachment; filename="dingusfish.gif"') + eq(self._im.get_params(header='content-disposition'), + [('attachment', ''), ('filename', 'dingusfish.gif')]) + eq(self._im.get_param('filename', header='content-disposition'), + 'dingusfish.gif') + missing = [] + eq(self._im.get_param('attachment', header='content-disposition'), '') + unless(self._im.get_param('foo', failobj=missing, + header='content-disposition') is missing) + # Try some missing stuff + unless(self._im.get_param('foobar', missing) is missing) + unless(self._im.get_param('attachment', missing, + header='foobar') is missing) + + + +# Test the basic MIMEText class +class TestMIMEText(unittest.TestCase): + def setUp(self): + self._msg = MIMEText('hello there') + + def test_types(self): + eq = self.assertEqual + unless = self.assertTrue + eq(self._msg.get_content_type(), 'text/plain') + eq(self._msg.get_param('charset'), 'us-ascii') + missing = [] + unless(self._msg.get_param('foobar', missing) is missing) + unless(self._msg.get_param('charset', missing, header='foobar') + is missing) + + def test_payload(self): + self.assertEqual(self._msg.get_payload(), 'hello there') + self.assertTrue(not self._msg.is_multipart()) + + def test_charset(self): + eq = self.assertEqual + msg = MIMEText('hello there', _charset='us-ascii') + eq(msg.get_charset().input_charset, 'us-ascii') + eq(msg['content-type'], 'text/plain; charset="us-ascii"') + + def test_7bit_unicode_input(self): + eq = self.assertEqual + msg = MIMEText(u'hello there', _charset='us-ascii') + eq(msg.get_charset().input_charset, 'us-ascii') + eq(msg['content-type'], 'text/plain; charset="us-ascii"') + + def test_7bit_unicode_input_no_charset(self): + eq = self.assertEqual + msg = MIMEText(u'hello there') + eq(msg.get_charset(), 'us-ascii') + eq(msg['content-type'], 'text/plain; 
charset="us-ascii"') + self.assertTrue('hello there' in msg.as_string()) + + def test_8bit_unicode_input(self): + teststr = u'\u043a\u0438\u0440\u0438\u043b\u0438\u0446\u0430' + eq = self.assertEqual + msg = MIMEText(teststr, _charset='utf-8') + eq(msg.get_charset().output_charset, 'utf-8') + eq(msg['content-type'], 'text/plain; charset="utf-8"') + eq(msg.get_payload(decode=True), teststr.encode('utf-8')) + + def test_8bit_unicode_input_no_charset(self): + teststr = u'\u043a\u0438\u0440\u0438\u043b\u0438\u0446\u0430' + self.assertRaises(UnicodeEncodeError, MIMEText, teststr) + + + +# Test complicated multipart/* messages +class TestMultipart(TestEmailBase): + def setUp(self): + fp = openfile('PyBanner048.gif') + try: + data = fp.read() + finally: + fp.close() + + container = MIMEBase('multipart', 'mixed', boundary='BOUNDARY') + image = MIMEImage(data, name='dingusfish.gif') + image.add_header('content-disposition', 'attachment', + filename='dingusfish.gif') + intro = MIMEText('''\ +Hi there, + +This is the dingus fish. 
+''') + container.attach(intro) + container.attach(image) + container['From'] = 'Barry ' + container['To'] = 'Dingus Lovers ' + container['Subject'] = 'Here is your dingus fish' + + now = 987809702.54848599 + timetuple = time.localtime(now) + if timetuple[-1] == 0: + tzsecs = time.timezone + else: + tzsecs = time.altzone + if tzsecs > 0: + sign = '-' + else: + sign = '+' + tzoffset = ' %s%04d' % (sign, tzsecs // 36) + container['Date'] = time.strftime( + '%a, %d %b %Y %H:%M:%S', + time.localtime(now)) + tzoffset + self._msg = container + self._im = image + self._txt = intro + + def test_hierarchy(self): + # convenience + eq = self.assertEqual + unless = self.assertTrue + raises = self.assertRaises + # tests + m = self._msg + unless(m.is_multipart()) + eq(m.get_content_type(), 'multipart/mixed') + eq(len(m.get_payload()), 2) + raises(IndexError, m.get_payload, 2) + m0 = m.get_payload(0) + m1 = m.get_payload(1) + unless(m0 is self._txt) + unless(m1 is self._im) + eq(m.get_payload(), [m0, m1]) + unless(not m0.is_multipart()) + unless(not m1.is_multipart()) + + def test_empty_multipart_idempotent(self): + text = """\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + + +--BOUNDARY + + +--BOUNDARY-- +""" + msg = Parser().parsestr(text) + self.ndiffAssertEqual(text, msg.as_string()) + + def test_no_parts_in_a_multipart_with_none_epilogue(self): + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.set_boundary('BOUNDARY') + self.ndiffAssertEqual(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY + +--BOUNDARY--''') + + def test_no_parts_in_a_multipart_with_empty_epilogue(self): + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + 
outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.preamble = '' + outer.epilogue = '' + outer.set_boundary('BOUNDARY') + self.ndiffAssertEqual(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + + +--BOUNDARY + +--BOUNDARY-- +''') + + def test_one_part_in_a_multipart(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.set_boundary('BOUNDARY') + msg = MIMEText('hello world') + outer.attach(msg) + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + def test_seq_parts_in_a_multipart_with_empty_preamble(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.preamble = '' + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + + def test_seq_parts_in_a_multipart_with_none_preamble(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.preamble = None + msg = MIMEText('hello world') + outer.attach(msg) + 
outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + + def test_seq_parts_in_a_multipart_with_none_epilogue(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.epilogue = None + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + + def test_seq_parts_in_a_multipart_with_empty_epilogue(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.epilogue = '' + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY-- +''') + + + def test_seq_parts_in_a_multipart_with_nl_epilogue(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.epilogue = '\n' + msg = MIMEText('hello world') + 
outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY-- + +''') + + def test_message_external_body(self): + eq = self.assertEqual + msg = self._msgobj('msg_36.txt') + eq(len(msg.get_payload()), 2) + msg1 = msg.get_payload(1) + eq(msg1.get_content_type(), 'multipart/alternative') + eq(len(msg1.get_payload()), 2) + for subpart in msg1.get_payload(): + eq(subpart.get_content_type(), 'message/external-body') + eq(len(subpart.get_payload()), 1) + subsubpart = subpart.get_payload(0) + eq(subsubpart.get_content_type(), 'text/plain') + + def test_double_boundary(self): + # msg_37.txt is a multipart that contains two dash-boundary's in a + # row. Our interpretation of RFC 2046 calls for ignoring the second + # and subsequent boundaries. + msg = self._msgobj('msg_37.txt') + self.assertEqual(len(msg.get_payload()), 3) + + def test_nested_inner_contains_outer_boundary(self): + eq = self.ndiffAssertEqual + # msg_38.txt has an inner part that contains outer boundaries. My + # interpretation of RFC 2046 (based on sections 5.1 and 5.1.2) say + # these are illegal and should be interpreted as unterminated inner + # parts. + msg = self._msgobj('msg_38.txt') + sfp = StringIO() + Iterators._structure(msg, sfp) + eq(sfp.getvalue(), """\ +multipart/mixed + multipart/mixed + multipart/alternative + text/plain + text/plain + text/plain + text/plain +""") + + def test_nested_with_same_boundary(self): + eq = self.ndiffAssertEqual + # msg 39.txt is similarly evil in that it's got inner parts that use + # the same boundary as outer parts. 
Again, I believe the way this is + # parsed is closest to the spirit of RFC 2046 + msg = self._msgobj('msg_39.txt') + sfp = StringIO() + Iterators._structure(msg, sfp) + eq(sfp.getvalue(), """\ +multipart/mixed + multipart/mixed + multipart/alternative + application/octet-stream + application/octet-stream + text/plain +""") + + def test_boundary_in_non_multipart(self): + msg = self._msgobj('msg_40.txt') + self.assertEqual(msg.as_string(), '''\ +MIME-Version: 1.0 +Content-Type: text/html; boundary="--961284236552522269" + +----961284236552522269 +Content-Type: text/html; +Content-Transfer-Encoding: 7Bit + + + +----961284236552522269-- +''') + + def test_boundary_with_leading_space(self): + eq = self.assertEqual + msg = email.message_from_string('''\ +MIME-Version: 1.0 +Content-Type: multipart/mixed; boundary=" XXXX" + +-- XXXX +Content-Type: text/plain + + +-- XXXX +Content-Type: text/plain + +-- XXXX-- +''') + self.assertTrue(msg.is_multipart()) + eq(msg.get_boundary(), ' XXXX') + eq(len(msg.get_payload()), 2) + + def test_boundary_without_trailing_newline(self): + m = Parser().parsestr("""\ +Content-Type: multipart/mixed; boundary="===============0012394164==" +MIME-Version: 1.0 + +--===============0012394164== +Content-Type: image/file1.jpg +MIME-Version: 1.0 +Content-Transfer-Encoding: base64 + +YXNkZg== +--===============0012394164==--""") + self.assertEqual(m.get_payload(0).get_payload(), 'YXNkZg==') + + + +# Test some badly formatted messages +class TestNonConformant(TestEmailBase): + def test_parse_missing_minor_type(self): + eq = self.assertEqual + msg = self._msgobj('msg_14.txt') + eq(msg.get_content_type(), 'text/plain') + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + + def test_same_boundary_inner_outer(self): + unless = self.assertTrue + msg = self._msgobj('msg_15.txt') + # XXX We can probably eventually do better + inner = msg.get_payload(0) + unless(hasattr(inner, 'defects')) + self.assertEqual(len(inner.defects), 
1) + unless(isinstance(inner.defects[0], + Errors.StartBoundaryNotFoundDefect)) + + def test_multipart_no_boundary(self): + unless = self.assertTrue + msg = self._msgobj('msg_25.txt') + unless(isinstance(msg.get_payload(), str)) + self.assertEqual(len(msg.defects), 2) + unless(isinstance(msg.defects[0], Errors.NoBoundaryInMultipartDefect)) + unless(isinstance(msg.defects[1], + Errors.MultipartInvariantViolationDefect)) + + def test_invalid_content_type(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + msg = Message() + # RFC 2045, $5.2 says invalid yields text/plain + msg['Content-Type'] = 'text' + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + eq(msg.get_content_type(), 'text/plain') + # Clear the old value and try something /really/ invalid + del msg['content-type'] + msg['Content-Type'] = 'foo' + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + eq(msg.get_content_type(), 'text/plain') + # Still, make sure that the message is idempotently generated + s = StringIO() + g = Generator(s) + g.flatten(msg) + neq(s.getvalue(), 'Content-Type: foo\n\n') + + def test_no_start_boundary(self): + eq = self.ndiffAssertEqual + msg = self._msgobj('msg_31.txt') + eq(msg.get_payload(), """\ +--BOUNDARY +Content-Type: text/plain + +message 1 + +--BOUNDARY +Content-Type: text/plain + +message 2 + +--BOUNDARY-- +""") + + def test_no_separating_blank_line(self): + eq = self.ndiffAssertEqual + msg = self._msgobj('msg_35.txt') + eq(msg.as_string(), """\ +From: aperson at dom.ain +To: bperson at dom.ain +Subject: here's something interesting + +counter to RFC 2822, there's no separating newline here +""") + + def test_lying_multipart(self): + unless = self.assertTrue + msg = self._msgobj('msg_41.txt') + unless(hasattr(msg, 'defects')) + self.assertEqual(len(msg.defects), 2) + unless(isinstance(msg.defects[0], Errors.NoBoundaryInMultipartDefect)) + unless(isinstance(msg.defects[1], + 
Errors.MultipartInvariantViolationDefect)) + + def test_missing_start_boundary(self): + outer = self._msgobj('msg_42.txt') + # The message structure is: + # + # multipart/mixed + # text/plain + # message/rfc822 + # multipart/mixed [*] + # + # [*] This message is missing its start boundary + bad = outer.get_payload(1).get_payload(0) + self.assertEqual(len(bad.defects), 1) + self.assertTrue(isinstance(bad.defects[0], + Errors.StartBoundaryNotFoundDefect)) + + def test_first_line_is_continuation_header(self): + eq = self.assertEqual + m = ' Line 1\nLine 2\nLine 3' + msg = email.message_from_string(m) + eq(msg.keys(), []) + eq(msg.get_payload(), 'Line 2\nLine 3') + eq(len(msg.defects), 1) + self.assertTrue(isinstance(msg.defects[0], + Errors.FirstHeaderLineIsContinuationDefect)) + eq(msg.defects[0].line, ' Line 1\n') + + + + +# Test RFC 2047 header encoding and decoding +class TestRFC2047(unittest.TestCase): + def test_rfc2047_multiline(self): + eq = self.assertEqual + s = """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz + foo bar =?mac-iceland?q?r=8Aksm=9Arg=8Cs?=""" + dh = decode_header(s) + eq(dh, [ + ('Re:', None), + ('r\x8aksm\x9arg\x8cs', 'mac-iceland'), + ('baz foo bar', None), + ('r\x8aksm\x9arg\x8cs', 'mac-iceland')]) + eq(str(make_header(dh)), + """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz foo bar + =?mac-iceland?q?r=8Aksm=9Arg=8Cs?=""") + + def test_whitespace_eater_unicode(self): + eq = self.assertEqual + s = '=?ISO-8859-1?Q?Andr=E9?= Pirard ' + dh = decode_header(s) + eq(dh, [('Andr\xe9', 'iso-8859-1'), ('Pirard ', None)]) + hu = unicode(make_header(dh)).encode('latin-1') + eq(hu, 'Andr\xe9 Pirard ') + + def test_whitespace_eater_unicode_2(self): + eq = self.assertEqual + s = 'The =?iso-8859-1?b?cXVpY2sgYnJvd24gZm94?= jumped over the =?iso-8859-1?b?bGF6eSBkb2c=?=' + dh = decode_header(s) + eq(dh, [('The', None), ('quick brown fox', 'iso-8859-1'), + ('jumped over the', None), ('lazy dog', 'iso-8859-1')]) + hu = make_header(dh).__unicode__() + eq(hu, u'The 
quick brown fox jumped over the lazy dog') + + def test_rfc2047_without_whitespace(self): + s = 'Sm=?ISO-8859-1?B?9g==?=rg=?ISO-8859-1?B?5Q==?=sbord' + dh = decode_header(s) + self.assertEqual(dh, [(s, None)]) + + def test_rfc2047_with_whitespace(self): + s = 'Sm =?ISO-8859-1?B?9g==?= rg =?ISO-8859-1?B?5Q==?= sbord' + dh = decode_header(s) + self.assertEqual(dh, [('Sm', None), ('\xf6', 'iso-8859-1'), + ('rg', None), ('\xe5', 'iso-8859-1'), + ('sbord', None)]) + + def test_rfc2047_B_bad_padding(self): + s = '=?iso-8859-1?B?%s?=' + data = [ # only test complete bytes + ('dm==', 'v'), ('dm=', 'v'), ('dm', 'v'), + ('dmk=', 'vi'), ('dmk', 'vi') + ] + for q, a in data: + dh = decode_header(s % q) + self.assertEqual(dh, [(a, 'iso-8859-1')]) + + def test_rfc2047_Q_invalid_digits(self): + # issue 10004. + s = '=?iso-8659-1?Q?andr=e9=zz?=' + self.assertEqual(decode_header(s), + [(b'andr\xe9=zz', 'iso-8659-1')]) + + +# Test the MIMEMessage class +class TestMIMEMessage(TestEmailBase): + def setUp(self): + fp = openfile('msg_11.txt') + try: + self._text = fp.read() + finally: + fp.close() + + def test_type_error(self): + self.assertRaises(TypeError, MIMEMessage, 'a plain string') + + def test_valid_argument(self): + eq = self.assertEqual + unless = self.assertTrue + subject = 'A sub-message' + m = Message() + m['Subject'] = subject + r = MIMEMessage(m) + eq(r.get_content_type(), 'message/rfc822') + payload = r.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + subpart = payload[0] + unless(subpart is m) + eq(subpart['subject'], subject) + + def test_bad_multipart(self): + eq = self.assertEqual + msg1 = Message() + msg1['Subject'] = 'subpart 1' + msg2 = Message() + msg2['Subject'] = 'subpart 2' + r = MIMEMessage(msg1) + self.assertRaises(Errors.MultipartConversionError, r.attach, msg2) + + def test_generate(self): + # First craft the message to be encapsulated + m = Message() + m['Subject'] = 'An enclosed message' + m.set_payload('Here is the body of the 
message.\n') + r = MIMEMessage(m) + r['Subject'] = 'The enclosing message' + s = StringIO() + g = Generator(s) + g.flatten(r) + self.assertEqual(s.getvalue(), """\ +Content-Type: message/rfc822 +MIME-Version: 1.0 +Subject: The enclosing message + +Subject: An enclosed message + +Here is the body of the message. +""") + + def test_parse_message_rfc822(self): + eq = self.assertEqual + unless = self.assertTrue + msg = self._msgobj('msg_11.txt') + eq(msg.get_content_type(), 'message/rfc822') + payload = msg.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + submsg = payload[0] + self.assertTrue(isinstance(submsg, Message)) + eq(submsg['subject'], 'An enclosed message') + eq(submsg.get_payload(), 'Here is the body of the message.\n') + + def test_dsn(self): + eq = self.assertEqual + unless = self.assertTrue + # msg 16 is a Delivery Status Notification, see RFC 1894 + msg = self._msgobj('msg_16.txt') + eq(msg.get_content_type(), 'multipart/report') + unless(msg.is_multipart()) + eq(len(msg.get_payload()), 3) + # Subpart 1 is a text/plain, human readable section + subpart = msg.get_payload(0) + eq(subpart.get_content_type(), 'text/plain') + eq(subpart.get_payload(), """\ +This report relates to a message you sent with the following header fields: + + Message-id: <002001c144a6$8752e060$56104586 at oxy.edu> + Date: Sun, 23 Sep 2001 20:10:55 -0700 + From: "Ian T. Henry" + To: SoCal Raves + Subject: [scr] yeah for Ians!! + +Your message cannot be delivered to the following recipients: + + Recipient address: jangel1 at cougar.noc.ucla.edu + Reason: recipient reached disk quota + +""") + # Subpart 2 contains the machine parsable DSN information. It + # consists of two blocks of headers, represented by two nested Message + # objects. + subpart = msg.get_payload(1) + eq(subpart.get_content_type(), 'message/delivery-status') + eq(len(subpart.get_payload()), 2) + # message/delivery-status should treat each block as a bunch of + # headers, i.e. 
a bunch of Message objects. + dsn1 = subpart.get_payload(0) + unless(isinstance(dsn1, Message)) + eq(dsn1['original-envelope-id'], '0GK500B4HD0888 at cougar.noc.ucla.edu') + eq(dsn1.get_param('dns', header='reporting-mta'), '') + # Try a missing one + eq(dsn1.get_param('nsd', header='reporting-mta'), None) + dsn2 = subpart.get_payload(1) + unless(isinstance(dsn2, Message)) + eq(dsn2['action'], 'failed') + eq(dsn2.get_params(header='original-recipient'), + [('rfc822', ''), ('jangel1 at cougar.noc.ucla.edu', '')]) + eq(dsn2.get_param('rfc822', header='final-recipient'), '') + # Subpart 3 is the original message + subpart = msg.get_payload(2) + eq(subpart.get_content_type(), 'message/rfc822') + payload = subpart.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + subsubpart = payload[0] + unless(isinstance(subsubpart, Message)) + eq(subsubpart.get_content_type(), 'text/plain') + eq(subsubpart['message-id'], + '<002001c144a6$8752e060$56104586 at oxy.edu>') + + def test_epilogue(self): + eq = self.ndiffAssertEqual + fp = openfile('msg_21.txt') + try: + text = fp.read() + finally: + fp.close() + msg = Message() + msg['From'] = 'aperson at dom.ain' + msg['To'] = 'bperson at dom.ain' + msg['Subject'] = 'Test' + msg.preamble = 'MIME message' + msg.epilogue = 'End of MIME message\n' + msg1 = MIMEText('One') + msg2 = MIMEText('Two') + msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY') + msg.attach(msg1) + msg.attach(msg2) + sfp = StringIO() + g = Generator(sfp) + g.flatten(msg) + eq(sfp.getvalue(), text) + + def test_no_nl_preamble(self): + eq = self.ndiffAssertEqual + msg = Message() + msg['From'] = 'aperson at dom.ain' + msg['To'] = 'bperson at dom.ain' + msg['Subject'] = 'Test' + msg.preamble = 'MIME message' + msg.epilogue = '' + msg1 = MIMEText('One') + msg2 = MIMEText('Two') + msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY') + msg.attach(msg1) + msg.attach(msg2) + eq(msg.as_string(), """\ +From: aperson at 
dom.ain +To: bperson at dom.ain +Subject: Test +Content-Type: multipart/mixed; boundary="BOUNDARY" + +MIME message +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +One +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +Two +--BOUNDARY-- +""") + + def test_default_type(self): + eq = self.assertEqual + fp = openfile('msg_30.txt') + try: + msg = email.message_from_file(fp) + finally: + fp.close() + container1 = msg.get_payload(0) + eq(container1.get_default_type(), 'message/rfc822') + eq(container1.get_content_type(), 'message/rfc822') + container2 = msg.get_payload(1) + eq(container2.get_default_type(), 'message/rfc822') + eq(container2.get_content_type(), 'message/rfc822') + container1a = container1.get_payload(0) + eq(container1a.get_default_type(), 'text/plain') + eq(container1a.get_content_type(), 'text/plain') + container2a = container2.get_payload(0) + eq(container2a.get_default_type(), 'text/plain') + eq(container2a.get_content_type(), 'text/plain') + + def test_default_type_with_explicit_container_type(self): + eq = self.assertEqual + fp = openfile('msg_28.txt') + try: + msg = email.message_from_file(fp) + finally: + fp.close() + container1 = msg.get_payload(0) + eq(container1.get_default_type(), 'message/rfc822') + eq(container1.get_content_type(), 'message/rfc822') + container2 = msg.get_payload(1) + eq(container2.get_default_type(), 'message/rfc822') + eq(container2.get_content_type(), 'message/rfc822') + container1a = container1.get_payload(0) + eq(container1a.get_default_type(), 'text/plain') + eq(container1a.get_content_type(), 'text/plain') + container2a = container2.get_payload(0) + eq(container2a.get_default_type(), 'text/plain') + eq(container2a.get_content_type(), 'text/plain') + + def test_default_type_non_parsed(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + # Set up container + container = 
MIMEMultipart('digest', 'BOUNDARY') + container.epilogue = '' + # Set up subparts + subpart1a = MIMEText('message 1\n') + subpart2a = MIMEText('message 2\n') + subpart1 = MIMEMessage(subpart1a) + subpart2 = MIMEMessage(subpart2a) + container.attach(subpart1) + container.attach(subpart2) + eq(subpart1.get_content_type(), 'message/rfc822') + eq(subpart1.get_default_type(), 'message/rfc822') + eq(subpart2.get_content_type(), 'message/rfc822') + eq(subpart2.get_default_type(), 'message/rfc822') + neq(container.as_string(0), '''\ +Content-Type: multipart/digest; boundary="BOUNDARY" +MIME-Version: 1.0 + +--BOUNDARY +Content-Type: message/rfc822 +MIME-Version: 1.0 + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 1 + +--BOUNDARY +Content-Type: message/rfc822 +MIME-Version: 1.0 + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 2 + +--BOUNDARY-- +''') + del subpart1['content-type'] + del subpart1['mime-version'] + del subpart2['content-type'] + del subpart2['mime-version'] + eq(subpart1.get_content_type(), 'message/rfc822') + eq(subpart1.get_default_type(), 'message/rfc822') + eq(subpart2.get_content_type(), 'message/rfc822') + eq(subpart2.get_default_type(), 'message/rfc822') + neq(container.as_string(0), '''\ +Content-Type: multipart/digest; boundary="BOUNDARY" +MIME-Version: 1.0 + +--BOUNDARY + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 1 + +--BOUNDARY + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 2 + +--BOUNDARY-- +''') + + def test_mime_attachments_in_constructor(self): + eq = self.assertEqual + text1 = MIMEText('') + text2 = MIMEText('') + msg = MIMEMultipart(_subparts=(text1, text2)) + eq(len(msg.get_payload()), 2) + eq(msg.get_payload(0), text1) + eq(msg.get_payload(1), text2) + + def 
test_default_multipart_constructor(self): + msg = MIMEMultipart() + self.assertTrue(msg.is_multipart()) + + +# A general test of parser->model->generator idempotency. IOW, read a message +# in, parse it into a message object tree, then without touching the tree, +# regenerate the plain text. The original text and the transformed text +# should be identical. Note: that we ignore the Unix-From since that may +# contain a changed date. +class TestIdempotent(TestEmailBase): + def _msgobj(self, filename): + fp = openfile(filename) + try: + data = fp.read() + finally: + fp.close() + msg = email.message_from_string(data) + return msg, data + + def _idempotent(self, msg, text): + eq = self.ndiffAssertEqual + s = StringIO() + g = Generator(s, maxheaderlen=0) + g.flatten(msg) + eq(text, s.getvalue()) + + def test_parse_text_message(self): + eq = self.assertEqual + msg, text = self._msgobj('msg_01.txt') + eq(msg.get_content_type(), 'text/plain') + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + eq(msg.get_params()[1], ('charset', 'us-ascii')) + eq(msg.get_param('charset'), 'us-ascii') + eq(msg.preamble, None) + eq(msg.epilogue, None) + self._idempotent(msg, text) + + def test_parse_untyped_message(self): + eq = self.assertEqual + msg, text = self._msgobj('msg_03.txt') + eq(msg.get_content_type(), 'text/plain') + eq(msg.get_params(), None) + eq(msg.get_param('charset'), None) + self._idempotent(msg, text) + + def test_simple_multipart(self): + msg, text = self._msgobj('msg_04.txt') + self._idempotent(msg, text) + + def test_MIME_digest(self): + msg, text = self._msgobj('msg_02.txt') + self._idempotent(msg, text) + + def test_long_header(self): + msg, text = self._msgobj('msg_27.txt') + self._idempotent(msg, text) + + def test_MIME_digest_with_part_headers(self): + msg, text = self._msgobj('msg_28.txt') + self._idempotent(msg, text) + + def test_mixed_with_image(self): + msg, text = self._msgobj('msg_06.txt') + self._idempotent(msg, text) + + 
def test_multipart_report(self): + msg, text = self._msgobj('msg_05.txt') + self._idempotent(msg, text) + + def test_dsn(self): + msg, text = self._msgobj('msg_16.txt') + self._idempotent(msg, text) + + def test_preamble_epilogue(self): + msg, text = self._msgobj('msg_21.txt') + self._idempotent(msg, text) + + def test_multipart_one_part(self): + msg, text = self._msgobj('msg_23.txt') + self._idempotent(msg, text) + + def test_multipart_no_parts(self): + msg, text = self._msgobj('msg_24.txt') + self._idempotent(msg, text) + + def test_no_start_boundary(self): + msg, text = self._msgobj('msg_31.txt') + self._idempotent(msg, text) + + def test_rfc2231_charset(self): + msg, text = self._msgobj('msg_32.txt') + self._idempotent(msg, text) + + def test_more_rfc2231_parameters(self): + msg, text = self._msgobj('msg_33.txt') + self._idempotent(msg, text) + + def test_text_plain_in_a_multipart_digest(self): + msg, text = self._msgobj('msg_34.txt') + self._idempotent(msg, text) + + def test_nested_multipart_mixeds(self): + msg, text = self._msgobj('msg_12a.txt') + self._idempotent(msg, text) + + def test_message_external_body_idempotent(self): + msg, text = self._msgobj('msg_36.txt') + self._idempotent(msg, text) + + def test_content_type(self): + eq = self.assertEqual + unless = self.assertTrue + # Get a message object and reset the seek pointer for other tests + msg, text = self._msgobj('msg_05.txt') + eq(msg.get_content_type(), 'multipart/report') + # Test the Content-Type: parameters + params = {} + for pk, pv in msg.get_params(): + params[pk] = pv + eq(params['report-type'], 'delivery-status') + eq(params['boundary'], 'D1690A7AC1.996856090/mail.example.com') + eq(msg.preamble, 'This is a MIME-encapsulated message.\n') + eq(msg.epilogue, '\n') + eq(len(msg.get_payload()), 3) + # Make sure the subparts are what we expect + msg1 = msg.get_payload(0) + eq(msg1.get_content_type(), 'text/plain') + eq(msg1.get_payload(), 'Yadda yadda yadda\n') + msg2 = msg.get_payload(1) + 
eq(msg2.get_content_type(), 'text/plain') + eq(msg2.get_payload(), 'Yadda yadda yadda\n') + msg3 = msg.get_payload(2) + eq(msg3.get_content_type(), 'message/rfc822') + self.assertTrue(isinstance(msg3, Message)) + payload = msg3.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + msg4 = payload[0] + unless(isinstance(msg4, Message)) + eq(msg4.get_payload(), 'Yadda yadda yadda\n') + + def test_parser(self): + eq = self.assertEqual + unless = self.assertTrue + msg, text = self._msgobj('msg_06.txt') + # Check some of the outer headers + eq(msg.get_content_type(), 'message/rfc822') + # Make sure the payload is a list of exactly one sub-Message, and that + # that submessage has a type of text/plain + payload = msg.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + msg1 = payload[0] + self.assertTrue(isinstance(msg1, Message)) + eq(msg1.get_content_type(), 'text/plain') + self.assertTrue(isinstance(msg1.get_payload(), str)) + eq(msg1.get_payload(), '\n') + + + +# Test various other bits of the package's functionality +class TestMiscellaneous(TestEmailBase): + def test_message_from_string(self): + fp = openfile('msg_01.txt') + try: + text = fp.read() + finally: + fp.close() + msg = email.message_from_string(text) + s = StringIO() + # Don't wrap/continue long headers since we're trying to test + # idempotency. + g = Generator(s, maxheaderlen=0) + g.flatten(msg) + self.assertEqual(text, s.getvalue()) + + def test_message_from_file(self): + fp = openfile('msg_01.txt') + try: + text = fp.read() + fp.seek(0) + msg = email.message_from_file(fp) + s = StringIO() + # Don't wrap/continue long headers since we're trying to test + # idempotency. 
+ g = Generator(s, maxheaderlen=0) + g.flatten(msg) + self.assertEqual(text, s.getvalue()) + finally: + fp.close() + + def test_message_from_string_with_class(self): + unless = self.assertTrue + fp = openfile('msg_01.txt') + try: + text = fp.read() + finally: + fp.close() + # Create a subclass + class MyMessage(Message): + pass + + msg = email.message_from_string(text, MyMessage) + unless(isinstance(msg, MyMessage)) + # Try something more complicated + fp = openfile('msg_02.txt') + try: + text = fp.read() + finally: + fp.close() + msg = email.message_from_string(text, MyMessage) + for subpart in msg.walk(): + unless(isinstance(subpart, MyMessage)) + + def test_message_from_file_with_class(self): + unless = self.assertTrue + # Create a subclass + class MyMessage(Message): + pass + + fp = openfile('msg_01.txt') + try: + msg = email.message_from_file(fp, MyMessage) + finally: + fp.close() + unless(isinstance(msg, MyMessage)) + # Try something more complicated + fp = openfile('msg_02.txt') + try: + msg = email.message_from_file(fp, MyMessage) + finally: + fp.close() + for subpart in msg.walk(): + unless(isinstance(subpart, MyMessage)) + + def test__all__(self): + module = __import__('email') + all = module.__all__ + all.sort() + self.assertEqual(all, [ + # Old names + 'Charset', 'Encoders', 'Errors', 'Generator', + 'Header', 'Iterators', 'MIMEAudio', 'MIMEBase', + 'MIMEImage', 'MIMEMessage', 'MIMEMultipart', + 'MIMENonMultipart', 'MIMEText', 'Message', + 'Parser', 'Utils', 'base64MIME', + # new names + 'base64mime', 'charset', 'encoders', 'errors', 'generator', + 'header', 'iterators', 'message', 'message_from_file', + 'message_from_string', 'mime', 'parser', + 'quopriMIME', 'quoprimime', 'utils', + ]) + + def test_formatdate(self): + now = time.time() + self.assertEqual(Utils.parsedate(Utils.formatdate(now))[:6], + time.gmtime(now)[:6]) + + def test_formatdate_localtime(self): + now = time.time() + self.assertEqual( + Utils.parsedate(Utils.formatdate(now, 
localtime=True))[:6], + time.localtime(now)[:6]) + + def test_formatdate_usegmt(self): + now = time.time() + self.assertEqual( + Utils.formatdate(now, localtime=False), + time.strftime('%a, %d %b %Y %H:%M:%S -0000', time.gmtime(now))) + self.assertEqual( + Utils.formatdate(now, localtime=False, usegmt=True), + time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(now))) + + def test_parsedate_none(self): + self.assertEqual(Utils.parsedate(''), None) + + def test_parsedate_compact(self): + # The FWS after the comma is optional + self.assertEqual(Utils.parsedate('Wed,3 Apr 2002 14:58:26 +0800'), + Utils.parsedate('Wed, 3 Apr 2002 14:58:26 +0800')) + + def test_parsedate_no_dayofweek(self): + eq = self.assertEqual + eq(Utils.parsedate_tz('25 Feb 2003 13:47:26 -0800'), + (2003, 2, 25, 13, 47, 26, 0, 1, -1, -28800)) + + def test_parsedate_compact_no_dayofweek(self): + eq = self.assertEqual + eq(Utils.parsedate_tz('5 Feb 2003 13:47:26 -0800'), + (2003, 2, 5, 13, 47, 26, 0, 1, -1, -28800)) + + def test_parsedate_acceptable_to_time_functions(self): + eq = self.assertEqual + timetup = Utils.parsedate('5 Feb 2003 13:47:26 -0800') + t = int(time.mktime(timetup)) + eq(time.localtime(t)[:6], timetup[:6]) + eq(int(time.strftime('%Y', timetup)), 2003) + timetup = Utils.parsedate_tz('5 Feb 2003 13:47:26 -0800') + t = int(time.mktime(timetup[:9])) + eq(time.localtime(t)[:6], timetup[:6]) + eq(int(time.strftime('%Y', timetup[:9])), 2003) + + def test_mktime_tz(self): + self.assertEqual(Utils.mktime_tz((1970, 1, 1, 0, 0, 0, + -1, -1, -1, 0)), 0) + self.assertEqual(Utils.mktime_tz((1970, 1, 1, 0, 0, 0, + -1, -1, -1, 1234)), -1234) + + def test_parsedate_y2k(self): + """Test for parsing a date with a two-digit year. + + Parsing a date with a two-digit year should return the correct + four-digit year. RFC822 allows two-digit years, but RFC2822 (which + obsoletes RFC822) requires four-digit years. 
+ + """ + self.assertEqual(Utils.parsedate_tz('25 Feb 03 13:47:26 -0800'), + Utils.parsedate_tz('25 Feb 2003 13:47:26 -0800')) + self.assertEqual(Utils.parsedate_tz('25 Feb 71 13:47:26 -0800'), + Utils.parsedate_tz('25 Feb 1971 13:47:26 -0800')) + + def test_parseaddr_empty(self): + self.assertEqual(Utils.parseaddr('<>'), ('', '')) + self.assertEqual(Utils.formataddr(Utils.parseaddr('<>')), '') + + def test_noquote_dump(self): + self.assertEqual( + Utils.formataddr(('A Silly Person', 'person at dom.ain')), + 'A Silly Person ') + + def test_escape_dump(self): + self.assertEqual( + Utils.formataddr(('A (Very) Silly Person', 'person at dom.ain')), + r'"A \(Very\) Silly Person" ') + a = r'A \(Special\) Person' + b = 'person at dom.ain' + self.assertEqual(Utils.parseaddr(Utils.formataddr((a, b))), (a, b)) + + def test_escape_backslashes(self): + self.assertEqual( + Utils.formataddr(('Arthur \Backslash\ Foobar', 'person at dom.ain')), + r'"Arthur \\Backslash\\ Foobar" ') + a = r'Arthur \Backslash\ Foobar' + b = 'person at dom.ain' + self.assertEqual(Utils.parseaddr(Utils.formataddr((a, b))), (a, b)) + + def test_name_with_dot(self): + x = 'John X. Doe ' + y = '"John X. Doe" ' + a, b = ('John X. Doe', 'jxd at example.com') + self.assertEqual(Utils.parseaddr(x), (a, b)) + self.assertEqual(Utils.parseaddr(y), (a, b)) + # formataddr() quotes the name if there's a dot in it + self.assertEqual(Utils.formataddr((a, b)), y) + + def test_parseaddr_preserves_quoted_pairs_in_addresses(self): + # issue 10005. Note that in the third test the second pair of + # backslashes is not actually a quoted pair because it is not inside a + # comment or quoted string: the address being parsed has a quoted + # string containing a quoted backslash, followed by 'example' and two + # backslashes, followed by another quoted string containing a space and + # the word 'example'. parseaddr copies those two backslashes + # literally. 
Per rfc5322 this is not technically correct since a \ may + # not appear in an address outside of a quoted string. It is probably + # a sensible Postel interpretation, though. + eq = self.assertEqual + eq(Utils.parseaddr('""example" example"@example.com'), + ('', '""example" example"@example.com')) + eq(Utils.parseaddr('"\\"example\\" example"@example.com'), + ('', '"\\"example\\" example"@example.com')) + eq(Utils.parseaddr('"\\\\"example\\\\" example"@example.com'), + ('', '"\\\\"example\\\\" example"@example.com')) + + def test_multiline_from_comment(self): + x = """\ +Foo +\tBar """ + self.assertEqual(Utils.parseaddr(x), ('Foo Bar', 'foo at example.com')) + + def test_quote_dump(self): + self.assertEqual( + Utils.formataddr(('A Silly; Person', 'person at dom.ain')), + r'"A Silly; Person" ') + + def test_fix_eols(self): + eq = self.assertEqual + eq(Utils.fix_eols('hello'), 'hello') + eq(Utils.fix_eols('hello\n'), 'hello\r\n') + eq(Utils.fix_eols('hello\r'), 'hello\r\n') + eq(Utils.fix_eols('hello\r\n'), 'hello\r\n') + eq(Utils.fix_eols('hello\n\r'), 'hello\r\n\r\n') + + def test_charset_richcomparisons(self): + eq = self.assertEqual + ne = self.assertNotEqual + cset1 = Charset() + cset2 = Charset() + eq(cset1, 'us-ascii') + eq(cset1, 'US-ASCII') + eq(cset1, 'Us-AsCiI') + eq('us-ascii', cset1) + eq('US-ASCII', cset1) + eq('Us-AsCiI', cset1) + ne(cset1, 'usascii') + ne(cset1, 'USASCII') + ne(cset1, 'UsAsCiI') + ne('usascii', cset1) + ne('USASCII', cset1) + ne('UsAsCiI', cset1) + eq(cset1, cset2) + eq(cset2, cset1) + + def test_getaddresses(self): + eq = self.assertEqual + eq(Utils.getaddresses(['aperson at dom.ain (Al Person)', + 'Bud Person ']), + [('Al Person', 'aperson at dom.ain'), + ('Bud Person', 'bperson at dom.ain')]) + + def test_getaddresses_nasty(self): + eq = self.assertEqual + eq(Utils.getaddresses(['foo: ;']), [('', '')]) + eq(Utils.getaddresses( + ['[]*-- =~$']), + [('', ''), ('', ''), ('', '*--')]) + eq(Utils.getaddresses( + ['foo: ;', '"Jason R. 
Mastaler" ']), + [('', ''), ('Jason R. Mastaler', 'jason at dom.ain')]) + + def test_getaddresses_embedded_comment(self): + """Test proper handling of a nested comment""" + eq = self.assertEqual + addrs = Utils.getaddresses(['User ((nested comment)) ']) + eq(addrs[0][1], 'foo at bar.com') + + def test_utils_quote_unquote(self): + eq = self.assertEqual + msg = Message() + msg.add_header('content-disposition', 'attachment', + filename='foo\\wacky"name') + eq(msg.get_filename(), 'foo\\wacky"name') + + def test_get_body_encoding_with_bogus_charset(self): + charset = Charset('not a charset') + self.assertEqual(charset.get_body_encoding(), 'base64') + + def test_get_body_encoding_with_uppercase_charset(self): + eq = self.assertEqual + msg = Message() + msg['Content-Type'] = 'text/plain; charset=UTF-8' + eq(msg['content-type'], 'text/plain; charset=UTF-8') + charsets = msg.get_charsets() + eq(len(charsets), 1) + eq(charsets[0], 'utf-8') + charset = Charset(charsets[0]) + eq(charset.get_body_encoding(), 'base64') + msg.set_payload('hello world', charset=charset) + eq(msg.get_payload(), 'aGVsbG8gd29ybGQ=\n') + eq(msg.get_payload(decode=True), 'hello world') + eq(msg['content-transfer-encoding'], 'base64') + # Try another one + msg = Message() + msg['Content-Type'] = 'text/plain; charset="US-ASCII"' + charsets = msg.get_charsets() + eq(len(charsets), 1) + eq(charsets[0], 'us-ascii') + charset = Charset(charsets[0]) + eq(charset.get_body_encoding(), Encoders.encode_7or8bit) + msg.set_payload('hello world', charset=charset) + eq(msg.get_payload(), 'hello world') + eq(msg['content-transfer-encoding'], '7bit') + + def test_charsets_case_insensitive(self): + lc = Charset('us-ascii') + uc = Charset('US-ASCII') + self.assertEqual(lc.get_body_encoding(), uc.get_body_encoding()) + + def test_partial_falls_inside_message_delivery_status(self): + eq = self.ndiffAssertEqual + # The Parser interface provides chunks of data to FeedParser in 8192 + # byte gulps. 
SF bug #1076485 found one of those chunks inside + # message/delivery-status header block, which triggered an + # unreadline() of NeedMoreData. + msg = self._msgobj('msg_43.txt') + sfp = StringIO() + Iterators._structure(msg, sfp) + eq(sfp.getvalue(), """\ +multipart/report + text/plain + message/delivery-status + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/rfc822-headers +""") + + + +# Test the iterator/generators +class TestIterators(TestEmailBase): + def test_body_line_iterator(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + # First a simple non-multipart message + msg = self._msgobj('msg_01.txt') + it = Iterators.body_line_iterator(msg) + lines = list(it) + eq(len(lines), 6) + neq(EMPTYSTRING.join(lines), msg.get_payload()) + # Now a more complicated multipart + msg = self._msgobj('msg_02.txt') + it = Iterators.body_line_iterator(msg) + lines = list(it) + eq(len(lines), 43) + fp = openfile('msg_19.txt') + try: + neq(EMPTYSTRING.join(lines), fp.read()) + finally: + fp.close() + + def test_typed_subpart_iterator(self): + eq = self.assertEqual + msg = self._msgobj('msg_04.txt') + it = Iterators.typed_subpart_iterator(msg, 'text') + lines = [] + subparts = 0 + for subpart in it: + subparts += 1 + lines.append(subpart.get_payload()) + eq(subparts, 2) + eq(EMPTYSTRING.join(lines), """\ +a simple kind of mirror +to reflect upon our own +a simple kind of mirror +to reflect upon our own +""") + + def test_typed_subpart_iterator_default_type(self): + eq = self.assertEqual + msg = self._msgobj('msg_03.txt') + it = Iterators.typed_subpart_iterator(msg, 'text', 'plain') + lines = [] + subparts = 0 + for subpart in it: + subparts += 1 + 
lines.append(subpart.get_payload()) + eq(subparts, 1) + eq(EMPTYSTRING.join(lines), """\ + +Hi, + +Do you like this message? + +-Me +""") + + def test_pushCR_LF(self): + '''FeedParser BufferedSubFile.push() assumed it received complete + line endings. A CR ending one push() followed by a LF starting + the next push() added an empty line. + ''' + imt = [ + ("a\r \n", 2), + ("b", 0), + ("c\n", 1), + ("", 0), + ("d\r\n", 1), + ("e\r", 0), + ("\nf", 1), + ("\r\n", 1), + ] + from email.feedparser import BufferedSubFile, NeedMoreData + bsf = BufferedSubFile() + om = [] + nt = 0 + for il, n in imt: + bsf.push(il) + nt += n + n1 = 0 + while True: + ol = bsf.readline() + if ol == NeedMoreData: + break + om.append(ol) + n1 += 1 + self.assertTrue(n == n1) + self.assertTrue(len(om) == nt) + self.assertTrue(''.join([il for il, n in imt]) == ''.join(om)) + + + +class TestParsers(TestEmailBase): + def test_header_parser(self): + eq = self.assertEqual + # Parse only the headers of a complex multipart MIME document + fp = openfile('msg_02.txt') + try: + msg = HeaderParser().parse(fp) + finally: + fp.close() + eq(msg['from'], 'ppp-request at zzz.org') + eq(msg['to'], 'ppp at zzz.org') + eq(msg.get_content_type(), 'multipart/mixed') + self.assertFalse(msg.is_multipart()) + self.assertTrue(isinstance(msg.get_payload(), str)) + + def test_whitespace_continuation(self): + eq = self.assertEqual + # This message contains a line after the Subject: header that has only + # whitespace, but it is not empty! 
+ msg = email.message_from_string("""\ +From: aperson at dom.ain +To: bperson at dom.ain +Subject: the next line has a space on it +\x20 +Date: Mon, 8 Apr 2002 15:09:19 -0400 +Message-ID: spam + +Here's the message body +""") + eq(msg['subject'], 'the next line has a space on it\n ') + eq(msg['message-id'], 'spam') + eq(msg.get_payload(), "Here's the message body\n") + + def test_whitespace_continuation_last_header(self): + eq = self.assertEqual + # Like the previous test, but the subject line is the last + # header. + msg = email.message_from_string("""\ +From: aperson at dom.ain +To: bperson at dom.ain +Date: Mon, 8 Apr 2002 15:09:19 -0400 +Message-ID: spam +Subject: the next line has a space on it +\x20 + +Here's the message body +""") + eq(msg['subject'], 'the next line has a space on it\n ') + eq(msg['message-id'], 'spam') + eq(msg.get_payload(), "Here's the message body\n") + + def test_crlf_separation(self): + eq = self.assertEqual + fp = openfile('msg_26.txt', mode='rb') + try: + msg = Parser().parse(fp) + finally: + fp.close() + eq(len(msg.get_payload()), 2) + part1 = msg.get_payload(0) + eq(part1.get_content_type(), 'text/plain') + eq(part1.get_payload(), 'Simple email with attachment.\r\n\r\n') + part2 = msg.get_payload(1) + eq(part2.get_content_type(), 'application/riscos') + + def test_multipart_digest_with_extra_mime_headers(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + fp = openfile('msg_28.txt') + try: + msg = email.message_from_file(fp) + finally: + fp.close() + # Structure is: + # multipart/digest + # message/rfc822 + # text/plain + # message/rfc822 + # text/plain + eq(msg.is_multipart(), 1) + eq(len(msg.get_payload()), 2) + part1 = msg.get_payload(0) + eq(part1.get_content_type(), 'message/rfc822') + eq(part1.is_multipart(), 1) + eq(len(part1.get_payload()), 1) + part1a = part1.get_payload(0) + eq(part1a.is_multipart(), 0) + eq(part1a.get_content_type(), 'text/plain') + neq(part1a.get_payload(), 'message 1\n') + # next 
message/rfc822 + part2 = msg.get_payload(1) + eq(part2.get_content_type(), 'message/rfc822') + eq(part2.is_multipart(), 1) + eq(len(part2.get_payload()), 1) + part2a = part2.get_payload(0) + eq(part2a.is_multipart(), 0) + eq(part2a.get_content_type(), 'text/plain') + neq(part2a.get_payload(), 'message 2\n') + + def test_three_lines(self): + # A bug report by Andrew McNamara + lines = ['From: Andrew Person From', 'From']) + eq(msg.get_payload(), 'body') + + def test_rfc2822_space_not_allowed_in_header(self): + eq = self.assertEqual + m = '>From foo at example.com 11:25:53\nFrom: bar\n!"#QUX;~: zoo\n\nbody' + msg = email.message_from_string(m) + eq(len(msg.keys()), 0) + + def test_rfc2822_one_character_header(self): + eq = self.assertEqual + m = 'A: first header\nB: second header\nCC: third header\n\nbody' + msg = email.message_from_string(m) + headers = msg.keys() + headers.sort() + eq(headers, ['A', 'B', 'CC']) + eq(msg.get_payload(), 'body') + + def test_CRLFLF_at_end_of_part(self): + # issue 5610: feedparser should not eat two chars from body part ending + # with "\r\n\n". 
+ m = ( + "From: foo at bar.com\n" + "To: baz\n" + "Mime-Version: 1.0\n" + "Content-Type: multipart/mixed; boundary=BOUNDARY\n" + "\n" + "--BOUNDARY\n" + "Content-Type: text/plain\n" + "\n" + "body ending with CRLF newline\r\n" + "\n" + "--BOUNDARY--\n" + ) + msg = email.message_from_string(m) + self.assertTrue(msg.get_payload(0).get_payload().endswith('\r\n')) + + +class TestBase64(unittest.TestCase): + def test_len(self): + eq = self.assertEqual + eq(base64MIME.base64_len('hello'), + len(base64MIME.encode('hello', eol=''))) + for size in range(15): + if size == 0 : bsize = 0 + elif size <= 3 : bsize = 4 + elif size <= 6 : bsize = 8 + elif size <= 9 : bsize = 12 + elif size <= 12: bsize = 16 + else : bsize = 20 + eq(base64MIME.base64_len('x'*size), bsize) + + def test_decode(self): + eq = self.assertEqual + eq(base64MIME.decode(''), '') + eq(base64MIME.decode('aGVsbG8='), 'hello') + eq(base64MIME.decode('aGVsbG8=', 'X'), 'hello') + eq(base64MIME.decode('aGVsbG8NCndvcmxk\n', 'X'), 'helloXworld') + + def test_encode(self): + eq = self.assertEqual + eq(base64MIME.encode(''), '') + eq(base64MIME.encode('hello'), 'aGVsbG8=\n') + # Test the binary flag + eq(base64MIME.encode('hello\n'), 'aGVsbG8K\n') + eq(base64MIME.encode('hello\n', 0), 'aGVsbG8NCg==\n') + # Test the maxlinelen arg + eq(base64MIME.encode('xxxx ' * 20, maxlinelen=40), """\ +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg +eHh4eCB4eHh4IA== +""") + # Test the eol argument + eq(base64MIME.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r +eHh4eCB4eHh4IA==\r +""") + + def test_header_encode(self): + eq = self.assertEqual + he = base64MIME.header_encode + eq(he('hello'), '=?iso-8859-1?b?aGVsbG8=?=') + eq(he('hello\nworld'), '=?iso-8859-1?b?aGVsbG8NCndvcmxk?=') + # Test the charset option + 
eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?b?aGVsbG8=?=') + # Test the keep_eols flag + eq(he('hello\nworld', keep_eols=True), + '=?iso-8859-1?b?aGVsbG8Kd29ybGQ=?=') + # Test the maxlinelen argument + eq(he('xxxx ' * 20, maxlinelen=40), """\ +=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?= + =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?= + =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?= + =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?= + =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?= + =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""") + # Test the eol argument + eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?=\r + =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?=\r + =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?=\r + =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?=\r + =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?=\r + =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""") + + + +class TestQuopri(unittest.TestCase): + def setUp(self): + self.hlit = [chr(x) for x in range(ord('a'), ord('z')+1)] + \ + [chr(x) for x in range(ord('A'), ord('Z')+1)] + \ + [chr(x) for x in range(ord('0'), ord('9')+1)] + \ + ['!', '*', '+', '-', '/', ' '] + self.hnon = [chr(x) for x in range(256) if chr(x) not in self.hlit] + assert len(self.hlit) + len(self.hnon) == 256 + self.blit = [chr(x) for x in range(ord(' '), ord('~')+1)] + ['\t'] + self.blit.remove('=') + self.bnon = [chr(x) for x in range(256) if chr(x) not in self.blit] + assert len(self.blit) + len(self.bnon) == 256 + + def test_header_quopri_check(self): + for c in self.hlit: + self.assertFalse(quopriMIME.header_quopri_check(c)) + for c in self.hnon: + self.assertTrue(quopriMIME.header_quopri_check(c)) + + def test_body_quopri_check(self): + for c in self.blit: + self.assertFalse(quopriMIME.body_quopri_check(c)) + for c in self.bnon: + self.assertTrue(quopriMIME.body_quopri_check(c)) + + def test_header_quopri_len(self): + eq = self.assertEqual + hql = quopriMIME.header_quopri_len + enc = quopriMIME.header_encode + for s in ('hello', 'h at e@l 
at l@o@'): + # Empty charset and no line-endings. 7 == RFC chrome + eq(hql(s), len(enc(s, charset='', eol=''))-7) + for c in self.hlit: + eq(hql(c), 1) + for c in self.hnon: + eq(hql(c), 3) + + def test_body_quopri_len(self): + eq = self.assertEqual + bql = quopriMIME.body_quopri_len + for c in self.blit: + eq(bql(c), 1) + for c in self.bnon: + eq(bql(c), 3) + + def test_quote_unquote_idempotent(self): + for x in range(256): + c = chr(x) + self.assertEqual(quopriMIME.unquote(quopriMIME.quote(c)), c) + + def test_header_encode(self): + eq = self.assertEqual + he = quopriMIME.header_encode + eq(he('hello'), '=?iso-8859-1?q?hello?=') + eq(he('hello\nworld'), '=?iso-8859-1?q?hello=0D=0Aworld?=') + # Test the charset option + eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?q?hello?=') + # Test the keep_eols flag + eq(he('hello\nworld', keep_eols=True), '=?iso-8859-1?q?hello=0Aworld?=') + # Test a non-ASCII character + eq(he('hello\xc7there'), '=?iso-8859-1?q?hello=C7there?=') + # Test the maxlinelen argument + eq(he('xxxx ' * 20, maxlinelen=40), """\ +=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?= + =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?= + =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?= + =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?= + =?iso-8859-1?q?x_xxxx_xxxx_?=""") + # Test the eol argument + eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?=\r + =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?=\r + =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?=\r + =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?=\r + =?iso-8859-1?q?x_xxxx_xxxx_?=""") + + def test_decode(self): + eq = self.assertEqual + eq(quopriMIME.decode(''), '') + eq(quopriMIME.decode('hello'), 'hello') + eq(quopriMIME.decode('hello', 'X'), 'hello') + eq(quopriMIME.decode('hello\nworld', 'X'), 'helloXworld') + + def test_encode(self): + eq = self.assertEqual + eq(quopriMIME.encode(''), '') + eq(quopriMIME.encode('hello'), 'hello') + # Test the binary flag + eq(quopriMIME.encode('hello\r\nworld'), 
'hello\nworld') + eq(quopriMIME.encode('hello\r\nworld', 0), 'hello\nworld') + # Test the maxlinelen arg + eq(quopriMIME.encode('xxxx ' * 20, maxlinelen=40), """\ +xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx= + xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx= +x xxxx xxxx xxxx xxxx=20""") + # Test the eol argument + eq(quopriMIME.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=\r + xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=\r +x xxxx xxxx xxxx xxxx=20""") + eq(quopriMIME.encode("""\ +one line + +two line"""), """\ +one line + +two line""") + + + +# Test the Charset class +class TestCharset(unittest.TestCase): + def tearDown(self): + from email import Charset as CharsetModule + try: + del CharsetModule.CHARSETS['fake'] + except KeyError: + pass + + def test_idempotent(self): + eq = self.assertEqual + # Make sure us-ascii = no Unicode conversion + c = Charset('us-ascii') + s = 'Hello World!' + sp = c.to_splittable(s) + eq(s, c.from_splittable(sp)) + # test 8-bit idempotency with us-ascii + s = '\xa4\xa2\xa4\xa4\xa4\xa6\xa4\xa8\xa4\xaa' + sp = c.to_splittable(s) + eq(s, c.from_splittable(sp)) + + def test_body_encode(self): + eq = self.assertEqual + # Try a charset with QP body encoding + c = Charset('iso-8859-1') + eq('hello w=F6rld', c.body_encode('hello w\xf6rld')) + # Try a charset with Base64 body encoding + c = Charset('utf-8') + eq('aGVsbG8gd29ybGQ=\n', c.body_encode('hello world')) + # Try a charset with None body encoding + c = Charset('us-ascii') + eq('hello world', c.body_encode('hello world')) + # Try the convert argument, where input codec != output codec + c = Charset('euc-jp') + # With apologies to Tokio Kikuchi ;) + if not is_jython: + # TODO Jython with its Java-based codecs does not + # currently support trailing bytes in CJK texts + try: + eq('\x1b$B5FCO;~IW\x1b(B', + c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7')) + eq('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7', + c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7', False)) + 
except LookupError: + # We probably don't have the Japanese codecs installed + pass + # Testing SF bug #625509, which we have to fake, since there are no + # built-in encodings where the header encoding is QP but the body + # encoding is not. + from email import Charset as CharsetModule + CharsetModule.add_charset('fake', CharsetModule.QP, None) + c = Charset('fake') + eq('hello w\xf6rld', c.body_encode('hello w\xf6rld')) + + def test_unicode_charset_name(self): + charset = Charset(u'us-ascii') + self.assertEqual(str(charset), 'us-ascii') + self.assertRaises(Errors.CharsetError, Charset, 'asc\xffii') + + def test_codecs_aliases_accepted(self): + charset = Charset('utf8') + self.assertEqual(str(charset), 'utf-8') + + +# Test multilingual MIME headers. +class TestHeader(TestEmailBase): + def test_simple(self): + eq = self.ndiffAssertEqual + h = Header('Hello World!') + eq(h.encode(), 'Hello World!') + h.append(' Goodbye World!') + eq(h.encode(), 'Hello World! Goodbye World!') + + def test_simple_surprise(self): + eq = self.ndiffAssertEqual + h = Header('Hello World!') + eq(h.encode(), 'Hello World!') + h.append('Goodbye World!') + eq(h.encode(), 'Hello World! 
Goodbye World!') + + def test_header_needs_no_decoding(self): + h = 'no decoding needed' + self.assertEqual(decode_header(h), [(h, None)]) + + def test_long(self): + h = Header("I am the very model of a modern Major-General; I've information vegetable, animal, and mineral; I know the kings of England, and I quote the fights historical from Marathon to Waterloo, in order categorical; I'm very well acquainted, too, with matters mathematical; I understand equations, both the simple and quadratical; about binomial theorem I'm teeming with a lot o' news, with many cheerful facts about the square of the hypotenuse.", + maxlinelen=76) + for l in h.encode(splitchars=' ').split('\n '): + self.assertTrue(len(l) <= 76) + + def test_multilingual(self): + eq = self.ndiffAssertEqual + g = Charset("iso-8859-1") + cz = Charset("iso-8859-2") + utf8 = Charset("utf-8") + g_head = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. " + cz_head = "Finan\xe8ni metropole se hroutily pod tlakem jejich d\xf9vtipu.. " + utf8_head = u"\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das Nunstuck git und Slotermeyer? Ja! 
Beiherhund das Oder die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066\u3044\u307e\u3059\u3002".encode("utf-8") + h = Header(g_head, g) + h.append(cz_head, cz) + h.append(utf8_head, utf8) + enc = h.encode() + eq(enc, """\ +=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerderband_ko?= + =?iso-8859-1?q?mfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndischen_Wan?= + =?iso-8859-1?q?dgem=E4lden_vorbei=2C_gegen_die_rotierenden_Klingen_bef=F6?= + =?iso-8859-1?q?rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_metropole_se_hroutily?= + =?iso-8859-2?q?_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= =?utf-8?b?5q2j56K6?= + =?utf-8?b?44Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE44G+44Gb44KT44CC?= + =?utf-8?b?5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB44GC44Go44Gv44Gn?= + =?utf-8?b?44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CMV2VubiBpc3QgZGFz?= + =?utf-8?q?_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das_Oder_die_Fl?= + =?utf-8?b?aXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBo+OBpuOBhOOBvuOBmQ==?= + =?utf-8?b?44CC?=""") + eq(decode_header(enc), + [(g_head, "iso-8859-1"), (cz_head, "iso-8859-2"), + (utf8_head, "utf-8")]) + ustr = unicode(h) + eq(ustr.encode('utf-8'), + 'Die Mieter treten hier ein werden mit einem Foerderband ' + 'komfortabel den Korridor entlang, an s\xc3\xbcdl\xc3\xbcndischen ' + 'Wandgem\xc3\xa4lden vorbei, gegen die rotierenden Klingen ' + 'bef\xc3\xb6rdert. Finan\xc4\x8dni metropole se hroutily pod ' + 'tlakem jejich d\xc5\xafvtipu.. 
\xe6\xad\xa3\xe7\xa2\xba\xe3\x81' + '\xab\xe8\xa8\x80\xe3\x81\x86\xe3\x81\xa8\xe7\xbf\xbb\xe8\xa8\xb3' + '\xe3\x81\xaf\xe3\x81\x95\xe3\x82\x8c\xe3\x81\xa6\xe3\x81\x84\xe3' + '\x81\xbe\xe3\x81\x9b\xe3\x82\x93\xe3\x80\x82\xe4\xb8\x80\xe9\x83' + '\xa8\xe3\x81\xaf\xe3\x83\x89\xe3\x82\xa4\xe3\x83\x84\xe8\xaa\x9e' + '\xe3\x81\xa7\xe3\x81\x99\xe3\x81\x8c\xe3\x80\x81\xe3\x81\x82\xe3' + '\x81\xa8\xe3\x81\xaf\xe3\x81\xa7\xe3\x81\x9f\xe3\x82\x89\xe3\x82' + '\x81\xe3\x81\xa7\xe3\x81\x99\xe3\x80\x82\xe5\xae\x9f\xe9\x9a\x9b' + '\xe3\x81\xab\xe3\x81\xaf\xe3\x80\x8cWenn ist das Nunstuck git ' + 'und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt ' + 'gersput.\xe3\x80\x8d\xe3\x81\xa8\xe8\xa8\x80\xe3\x81\xa3\xe3\x81' + '\xa6\xe3\x81\x84\xe3\x81\xbe\xe3\x81\x99\xe3\x80\x82') + # Test make_header() + newh = make_header(decode_header(enc)) + eq(newh, enc) + + def test_header_ctor_default_args(self): + eq = self.ndiffAssertEqual + h = Header() + eq(h, '') + h.append('foo', Charset('iso-8859-1')) + eq(h, '=?iso-8859-1?q?foo?=') + + def test_explicit_maxlinelen(self): + eq = self.ndiffAssertEqual + hstr = 'A very long line that must get split to something other than at the 76th character boundary to test the non-default behavior' + h = Header(hstr) + eq(h.encode(), '''\ +A very long line that must get split to something other than at the 76th + character boundary to test the non-default behavior''') + h = Header(hstr, header_name='Subject') + eq(h.encode(), '''\ +A very long line that must get split to something other than at the + 76th character boundary to test the non-default behavior''') + h = Header(hstr, maxlinelen=1024, header_name='Subject') + eq(h.encode(), hstr) + + def test_us_ascii_header(self): + eq = self.assertEqual + s = 'hello' + x = decode_header(s) + eq(x, [('hello', None)]) + h = make_header(x) + eq(s, h.encode()) + + def test_string_charset(self): + eq = self.assertEqual + h = Header() + h.append('hello', 'iso-8859-1') + eq(h, '=?iso-8859-1?q?hello?=') + +## 
def test_unicode_error(self): +## raises = self.assertRaises +## raises(UnicodeError, Header, u'[P\xf6stal]', 'us-ascii') +## raises(UnicodeError, Header, '[P\xf6stal]', 'us-ascii') +## h = Header() +## raises(UnicodeError, h.append, u'[P\xf6stal]', 'us-ascii') +## raises(UnicodeError, h.append, '[P\xf6stal]', 'us-ascii') +## raises(UnicodeError, Header, u'\u83ca\u5730\u6642\u592b', 'iso-8859-1') + + def test_utf8_shortest(self): + eq = self.assertEqual + h = Header(u'p\xf6stal', 'utf-8') + eq(h.encode(), '=?utf-8?q?p=C3=B6stal?=') + h = Header(u'\u83ca\u5730\u6642\u592b', 'utf-8') + eq(h.encode(), '=?utf-8?b?6I+K5Zyw5pmC5aSr?=') + + def test_bad_8bit_header(self): + raises = self.assertRaises + eq = self.assertEqual + x = 'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big' + raises(UnicodeError, Header, x) + h = Header() + raises(UnicodeError, h.append, x) + eq(str(Header(x, errors='replace')), x) + h.append(x, errors='replace') + eq(str(h), x) + + def test_encoded_adjacent_nonencoded(self): + eq = self.assertEqual + h = Header() + h.append('hello', 'iso-8859-1') + h.append('world') + s = h.encode() + eq(s, '=?iso-8859-1?q?hello?= world') + h = make_header(decode_header(s)) + eq(h.encode(), s) + + def test_whitespace_eater(self): + eq = self.assertEqual + s = 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztk=?= =?koi8-r?q?=CA?= zz.' 
+ parts = decode_header(s) + eq(parts, [('Subject:', None), ('\xf0\xd2\xcf\xd7\xc5\xd2\xcb\xc1 \xce\xc1 \xc6\xc9\xce\xc1\xcc\xd8\xce\xd9\xca', 'koi8-r'), ('zz.', None)]) + hdr = make_header(parts) + eq(hdr.encode(), + 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztnK?= zz.') + + def test_broken_base64_header(self): + raises = self.assertRaises + s = 'Subject: =?EUC-KR?B?CSixpLDtKSC/7Liuvsax4iC6uLmwMcijIKHaILzSwd/H0SC8+LCjwLsgv7W/+Mj3I ?=' + raises(Errors.HeaderParseError, decode_header, s) + + # Issue 1078919 + def test_ascii_add_header(self): + msg = Message() + msg.add_header('Content-Disposition', 'attachment', + filename='bud.gif') + self.assertEqual('attachment; filename="bud.gif"', + msg['Content-Disposition']) + + def test_nonascii_add_header_via_triple(self): + msg = Message() + msg.add_header('Content-Disposition', 'attachment', + filename=('iso-8859-1', '', 'Fu\xdfballer.ppt')) + self.assertEqual( + 'attachment; filename*="iso-8859-1\'\'Fu%DFballer.ppt"', + msg['Content-Disposition']) + + def test_encode_unaliased_charset(self): + # Issue 1379416: when the charset has no output conversion, + # output was accidentally getting coerced to unicode. 
+ res = Header('abc','iso-8859-2').encode() + self.assertEqual(res, '=?iso-8859-2?q?abc?=') + self.assertIsInstance(res, str) + + +# Test RFC 2231 header parameters (en/de)coding +class TestRFC2231(TestEmailBase): + def test_get_param(self): + eq = self.assertEqual + msg = self._msgobj('msg_29.txt') + eq(msg.get_param('title'), + ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!')) + eq(msg.get_param('title', unquote=False), + ('us-ascii', 'en', '"This is even more ***fun*** isn\'t it!"')) + + def test_set_param(self): + eq = self.assertEqual + msg = Message() + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii') + eq(msg.get_param('title'), + ('us-ascii', '', 'This is even more ***fun*** isn\'t it!')) + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii', language='en') + eq(msg.get_param('title'), + ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!')) + msg = self._msgobj('msg_01.txt') + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii', language='en') + self.ndiffAssertEqual(msg.as_string(), """\ +Return-Path: +Delivered-To: bbb at zzz.org +Received: by mail.zzz.org (Postfix, from userid 889) + id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT) +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit +Message-ID: <15090.61304.110929.45684 at aaa.zzz.org> +From: bbb at ddd.com (John X. Doe) +To: bbb at zzz.org +Subject: This is a test message +Date: Fri, 4 May 2001 14:05:44 -0400 +Content-Type: text/plain; charset=us-ascii; + title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21" + + +Hi, + +Do you like this message? 
+ +-Me +""") + + def test_del_param(self): + eq = self.ndiffAssertEqual + msg = self._msgobj('msg_01.txt') + msg.set_param('foo', 'bar', charset='us-ascii', language='en') + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii', language='en') + msg.del_param('foo', header='Content-Type') + eq(msg.as_string(), """\ +Return-Path: +Delivered-To: bbb at zzz.org +Received: by mail.zzz.org (Postfix, from userid 889) + id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT) +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit +Message-ID: <15090.61304.110929.45684 at aaa.zzz.org> +From: bbb at ddd.com (John X. Doe) +To: bbb at zzz.org +Subject: This is a test message +Date: Fri, 4 May 2001 14:05:44 -0400 +Content-Type: text/plain; charset="us-ascii"; + title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21" + + +Hi, + +Do you like this message? + +-Me +""") + + def test_rfc2231_get_content_charset(self): + eq = self.assertEqual + msg = self._msgobj('msg_32.txt') + eq(msg.get_content_charset(), 'us-ascii') + + def test_rfc2231_no_language_or_charset(self): + m = '''\ +Content-Transfer-Encoding: 8bit +Content-Disposition: inline; filename="file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm" +Content-Type: text/html; NAME*0=file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEM; NAME*1=P_nsmail.htm + +''' + msg = email.message_from_string(m) + param = msg.get_param('NAME') + self.assertFalse(isinstance(param, tuple)) + self.assertEqual( + param, + 'file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm') + + def test_rfc2231_no_language_or_charset_in_filename(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="''This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + 'This is even more ***fun*** is it not.pdf') + + 
def test_rfc2231_no_language_or_charset_in_filename_encoded(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="''This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + 'This is even more ***fun*** is it not.pdf') + + def test_rfc2231_partly_encoded(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0="''This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual( + msg.get_filename(), + 'This%20is%20even%20more%20***fun*** is it not.pdf') + + def test_rfc2231_partly_nonencoded(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0="This%20is%20even%20more%20"; +\tfilename*1="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual( + msg.get_filename(), + 'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20is it not.pdf') + + def test_rfc2231_no_language_or_charset_in_boundary(self): + m = '''\ +Content-Type: multipart/alternative; +\tboundary*0*="''This%20is%20even%20more%20"; +\tboundary*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tboundary*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_boundary(), + 'This is even more ***fun*** is it not.pdf') + + def test_rfc2231_no_language_or_charset_in_charset(self): + # This is a nonsensical charset value, but tests the code anyway + m = '''\ +Content-Type: text/plain; +\tcharset*0*="This%20is%20even%20more%20"; +\tcharset*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tcharset*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_content_charset(), + 'this is even more ***fun*** is it not.pdf') + + def test_rfc2231_bad_encoding_in_filename(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="bogus'xx'This%20is%20even%20more%20"; 
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + 'This is even more ***fun*** is it not.pdf') + + def test_rfc2231_bad_encoding_in_charset(self): + m = """\ +Content-Type: text/plain; charset*=bogus''utf-8%E2%80%9D + +""" + msg = email.message_from_string(m) + # This should return None because non-ascii characters in the charset + # are not allowed. + self.assertEqual(msg.get_content_charset(), None) + + def test_rfc2231_bad_character_in_charset(self): + m = """\ +Content-Type: text/plain; charset*=ascii''utf-8%E2%80%9D + +""" + msg = email.message_from_string(m) + # This should return None because non-ascii characters in the charset + # are not allowed. + self.assertEqual(msg.get_content_charset(), None) + + def test_rfc2231_bad_character_in_filename(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="ascii'xx'This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2*="is it not.pdf%E2" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + u'This is even more ***fun*** is it not.pdf\ufffd') + + def test_rfc2231_unknown_encoding(self): + m = """\ +Content-Transfer-Encoding: 8bit +Content-Disposition: inline; filename*=X-UNKNOWN''myfile.txt + +""" + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), 'myfile.txt') + + def test_rfc2231_single_tick_in_filename_extended(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0*=\"Frank's\"; name*1*=\" Document\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, None) + eq(language, None) + eq(s, "Frank's Document") + + def test_rfc2231_single_tick_in_filename(self): + m = """\ +Content-Type: application/x-foo; name*0=\"Frank's\"; name*1=\" Document\" + +""" + msg = email.message_from_string(m) + param = msg.get_param('name') + 
self.assertFalse(isinstance(param, tuple)) + self.assertEqual(param, "Frank's Document") + + def test_rfc2231_tick_attack_extended(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0*=\"us-ascii'en-us'Frank's\"; name*1*=\" Document\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, 'us-ascii') + eq(language, 'en-us') + eq(s, "Frank's Document") + + def test_rfc2231_tick_attack(self): + m = """\ +Content-Type: application/x-foo; +\tname*0=\"us-ascii'en-us'Frank's\"; name*1=\" Document\" + +""" + msg = email.message_from_string(m) + param = msg.get_param('name') + self.assertFalse(isinstance(param, tuple)) + self.assertEqual(param, "us-ascii'en-us'Frank's Document") + + def test_rfc2231_no_extended_values(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; name=\"Frank's Document\" + +""" + msg = email.message_from_string(m) + eq(msg.get_param('name'), "Frank's Document") + + def test_rfc2231_encoded_then_unencoded_segments(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0*=\"us-ascii'en-us'My\"; +\tname*1=\" Document\"; +\tname*2*=\" For You\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, 'us-ascii') + eq(language, 'en-us') + eq(s, 'My Document For You') + + def test_rfc2231_unencoded_then_encoded_segments(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0=\"us-ascii'en-us'My\"; +\tname*1*=\" Document\"; +\tname*2*=\" For You\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, 'us-ascii') + eq(language, 'en-us') + eq(s, 'My Document For You') + + + +# Tests to ensure that signed parts of an email are completely preserved, as +# required by RFC1847 section 2.1. Note that these are incomplete, because the +# email package does not currently always preserve the body. 
See issue 1670765. +class TestSigned(TestEmailBase): + + def _msg_and_obj(self, filename): + fp = openfile(findfile(filename)) + try: + original = fp.read() + msg = email.message_from_string(original) + finally: + fp.close() + return original, msg + + def _signed_parts_eq(self, original, result): + # Extract the first mime part of each message + import re + repart = re.compile(r'^--([^\n]+)\n(.*?)\n--\1$', re.S | re.M) + inpart = repart.search(original).group(2) + outpart = repart.search(result).group(2) + self.assertEqual(outpart, inpart) + + def test_long_headers_as_string(self): + original, msg = self._msg_and_obj('msg_45.txt') + result = msg.as_string() + self._signed_parts_eq(original, result) + + def test_long_headers_flatten(self): + original, msg = self._msg_and_obj('msg_45.txt') + fp = StringIO() + Generator(fp).flatten(msg) + result = fp.getvalue() + self._signed_parts_eq(original, result) + + + +def _testclasses(): + mod = sys.modules[__name__] + return [getattr(mod, name) for name in dir(mod) if name.startswith('Test')] + + +def suite(): + suite = unittest.TestSuite() + for testclass in _testclasses(): + suite.addTest(unittest.makeSuite(testclass)) + return suite + + +def test_main(): + for testclass in _testclasses(): + run_unittest(testclass) + + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/Lib/email/test/test_email_renamed.py b/Lib/email/test/test_email_renamed.py new file mode 100644 --- /dev/null +++ b/Lib/email/test/test_email_renamed.py @@ -0,0 +1,3332 @@ +# Copyright (C) 2001-2007 Python Software Foundation +# Contact: email-sig at python.org +# email package unit tests + +import os +import sys +import time +import base64 +import difflib +import unittest +import warnings +from cStringIO import StringIO + +import email + +from email.charset import Charset +from email.header import Header, decode_header, make_header +from email.parser import Parser, HeaderParser +from email.generator import Generator, 
DecodedGenerator +from email.message import Message +from email.mime.application import MIMEApplication +from email.mime.audio import MIMEAudio +from email.mime.text import MIMEText +from email.mime.image import MIMEImage +from email.mime.base import MIMEBase +from email.mime.message import MIMEMessage +from email.mime.multipart import MIMEMultipart +from email import utils +from email import errors +from email import encoders +from email import iterators +from email import base64mime +from email import quoprimime + +from test.test_support import findfile, run_unittest, is_jython +from email.test import __file__ as landmark + + +NL = '\n' +EMPTYSTRING = '' +SPACE = ' ' + + + +def openfile(filename, mode='r'): + path = os.path.join(os.path.dirname(landmark), 'data', filename) + return open(path, mode) + + + +# Base test class +class TestEmailBase(unittest.TestCase): + def ndiffAssertEqual(self, first, second): + """Like assertEqual except use ndiff for readable output.""" + if first != second: + sfirst = str(first) + ssecond = str(second) + diff = difflib.ndiff(sfirst.splitlines(), ssecond.splitlines()) + fp = StringIO() + print >> fp, NL, NL.join(diff) + raise self.failureException, fp.getvalue() + + def _msgobj(self, filename): + fp = openfile(findfile(filename)) + try: + msg = email.message_from_file(fp) + finally: + fp.close() + return msg + + + +# Test various aspects of the Message class's API +class TestMessageAPI(TestEmailBase): + def test_get_all(self): + eq = self.assertEqual + msg = self._msgobj('msg_20.txt') + eq(msg.get_all('cc'), ['ccc at zzz.org', 'ddd at zzz.org', 'eee at zzz.org']) + eq(msg.get_all('xx', 'n/a'), 'n/a') + + def test_getset_charset(self): + eq = self.assertEqual + msg = Message() + eq(msg.get_charset(), None) + charset = Charset('iso-8859-1') + msg.set_charset(charset) + eq(msg['mime-version'], '1.0') + eq(msg.get_content_type(), 'text/plain') + eq(msg['content-type'], 'text/plain; charset="iso-8859-1"') + eq(msg.get_param('charset'), 
'iso-8859-1') + eq(msg['content-transfer-encoding'], 'quoted-printable') + eq(msg.get_charset().input_charset, 'iso-8859-1') + # Remove the charset + msg.set_charset(None) + eq(msg.get_charset(), None) + eq(msg['content-type'], 'text/plain') + # Try adding a charset when there's already MIME headers present + msg = Message() + msg['MIME-Version'] = '2.0' + msg['Content-Type'] = 'text/x-weird' + msg['Content-Transfer-Encoding'] = 'quinted-puntable' + msg.set_charset(charset) + eq(msg['mime-version'], '2.0') + eq(msg['content-type'], 'text/x-weird; charset="iso-8859-1"') + eq(msg['content-transfer-encoding'], 'quinted-puntable') + + def test_set_charset_from_string(self): + eq = self.assertEqual + msg = Message() + msg.set_charset('us-ascii') + eq(msg.get_charset().input_charset, 'us-ascii') + eq(msg['content-type'], 'text/plain; charset="us-ascii"') + + def test_set_payload_with_charset(self): + msg = Message() + charset = Charset('iso-8859-1') + msg.set_payload('This is a string payload', charset) + self.assertEqual(msg.get_charset().input_charset, 'iso-8859-1') + + def test_get_charsets(self): + eq = self.assertEqual + + msg = self._msgobj('msg_08.txt') + charsets = msg.get_charsets() + eq(charsets, [None, 'us-ascii', 'iso-8859-1', 'iso-8859-2', 'koi8-r']) + + msg = self._msgobj('msg_09.txt') + charsets = msg.get_charsets('dingbat') + eq(charsets, ['dingbat', 'us-ascii', 'iso-8859-1', 'dingbat', + 'koi8-r']) + + msg = self._msgobj('msg_12.txt') + charsets = msg.get_charsets() + eq(charsets, [None, 'us-ascii', 'iso-8859-1', None, 'iso-8859-2', + 'iso-8859-3', 'us-ascii', 'koi8-r']) + + def test_get_filename(self): + eq = self.assertEqual + + msg = self._msgobj('msg_04.txt') + filenames = [p.get_filename() for p in msg.get_payload()] + eq(filenames, ['msg.txt', 'msg.txt']) + + msg = self._msgobj('msg_07.txt') + subpart = msg.get_payload(1) + eq(subpart.get_filename(), 'dingusfish.gif') + + def test_get_filename_with_name_parameter(self): + eq = self.assertEqual + + 
msg = self._msgobj('msg_44.txt') + filenames = [p.get_filename() for p in msg.get_payload()] + eq(filenames, ['msg.txt', 'msg.txt']) + + def test_get_boundary(self): + eq = self.assertEqual + msg = self._msgobj('msg_07.txt') + # No quotes! + eq(msg.get_boundary(), 'BOUNDARY') + + def test_set_boundary(self): + eq = self.assertEqual + # This one has no existing boundary parameter, but the Content-Type: + # header appears fifth. + msg = self._msgobj('msg_01.txt') + msg.set_boundary('BOUNDARY') + header, value = msg.items()[4] + eq(header.lower(), 'content-type') + eq(value, 'text/plain; charset="us-ascii"; boundary="BOUNDARY"') + # This one has a Content-Type: header, with a boundary, stuck in the + # middle of its headers. Make sure the order is preserved; it should + # be fifth. + msg = self._msgobj('msg_04.txt') + msg.set_boundary('BOUNDARY') + header, value = msg.items()[4] + eq(header.lower(), 'content-type') + eq(value, 'multipart/mixed; boundary="BOUNDARY"') + # And this one has no Content-Type: header at all. + msg = self._msgobj('msg_03.txt') + self.assertRaises(errors.HeaderParseError, + msg.set_boundary, 'BOUNDARY') + + def test_get_decoded_payload(self): + eq = self.assertEqual + msg = self._msgobj('msg_10.txt') + # The outer message is a multipart + eq(msg.get_payload(decode=True), None) + # Subpart 1 is 7bit encoded + eq(msg.get_payload(0).get_payload(decode=True), + 'This is a 7bit encoded message.\n') + # Subpart 2 is quopri + eq(msg.get_payload(1).get_payload(decode=True), + '\xa1This is a Quoted Printable encoded message!\n') + # Subpart 3 is base64 + eq(msg.get_payload(2).get_payload(decode=True), + 'This is a Base64 encoded message.') + # Subpart 4 is base64 with a trailing newline, which + # used to be stripped (issue 7143). + eq(msg.get_payload(3).get_payload(decode=True), + 'This is a Base64 encoded message.\n') + # Subpart 5 has no Content-Transfer-Encoding: header. 
+ eq(msg.get_payload(4).get_payload(decode=True), + 'This has no Content-Transfer-Encoding: header.\n') + + def test_get_decoded_uu_payload(self): + eq = self.assertEqual + msg = Message() + msg.set_payload('begin 666 -\n+:&5L;&\\@=V]R;&0 \n \nend\n') + for cte in ('x-uuencode', 'uuencode', 'uue', 'x-uue'): + msg['content-transfer-encoding'] = cte + eq(msg.get_payload(decode=True), 'hello world') + # Now try some bogus data + msg.set_payload('foo') + eq(msg.get_payload(decode=True), 'foo') + + def test_decoded_generator(self): + eq = self.assertEqual + msg = self._msgobj('msg_07.txt') + fp = openfile('msg_17.txt') + try: + text = fp.read() + finally: + fp.close() + s = StringIO() + g = DecodedGenerator(s) + g.flatten(msg) + eq(s.getvalue(), text) + + def test__contains__(self): + msg = Message() + msg['From'] = 'Me' + msg['to'] = 'You' + # Check for case insensitivity + self.assertTrue('from' in msg) + self.assertTrue('From' in msg) + self.assertTrue('FROM' in msg) + self.assertTrue('to' in msg) + self.assertTrue('To' in msg) + self.assertTrue('TO' in msg) + + def test_as_string(self): + eq = self.assertEqual + msg = self._msgobj('msg_01.txt') + fp = openfile('msg_01.txt') + try: + # BAW 30-Mar-2009 Evil be here. So, the generator is broken with + # respect to long line breaking. It's also not idempotent when a + # header from a parsed message is continued with tabs rather than + # spaces. Before we fixed bug 1974 it was reversedly broken, + # i.e. headers that were continued with spaces got continued with + # tabs. For Python 2.x there's really no good fix and in Python + # 3.x all this stuff is re-written to be right(er). Chris Withers + # convinced me that using space as the default continuation + # character is less bad for more applications. 
+ text = fp.read().replace('\t', ' ') + finally: + fp.close() + self.ndiffAssertEqual(text, msg.as_string()) + fullrepr = str(msg) + lines = fullrepr.split('\n') + self.assertTrue(lines[0].startswith('From ')) + eq(text, NL.join(lines[1:])) + + def test_bad_param(self): + msg = email.message_from_string("Content-Type: blarg; baz; boo\n") + self.assertEqual(msg.get_param('baz'), '') + + def test_missing_filename(self): + msg = email.message_from_string("From: foo\n") + self.assertEqual(msg.get_filename(), None) + + def test_bogus_filename(self): + msg = email.message_from_string( + "Content-Disposition: blarg; filename\n") + self.assertEqual(msg.get_filename(), '') + + def test_missing_boundary(self): + msg = email.message_from_string("From: foo\n") + self.assertEqual(msg.get_boundary(), None) + + def test_get_params(self): + eq = self.assertEqual + msg = email.message_from_string( + 'X-Header: foo=one; bar=two; baz=three\n') + eq(msg.get_params(header='x-header'), + [('foo', 'one'), ('bar', 'two'), ('baz', 'three')]) + msg = email.message_from_string( + 'X-Header: foo; bar=one; baz=two\n') + eq(msg.get_params(header='x-header'), + [('foo', ''), ('bar', 'one'), ('baz', 'two')]) + eq(msg.get_params(), None) + msg = email.message_from_string( + 'X-Header: foo; bar="one"; baz=two\n') + eq(msg.get_params(header='x-header'), + [('foo', ''), ('bar', 'one'), ('baz', 'two')]) + + def test_get_param_liberal(self): + msg = Message() + msg['Content-Type'] = 'Content-Type: Multipart/mixed; boundary = "CPIMSSMTPC06p5f3tG"' + self.assertEqual(msg.get_param('boundary'), 'CPIMSSMTPC06p5f3tG') + + def test_get_param(self): + eq = self.assertEqual + msg = email.message_from_string( + "X-Header: foo=one; bar=two; baz=three\n") + eq(msg.get_param('bar', header='x-header'), 'two') + eq(msg.get_param('quuz', header='x-header'), None) + eq(msg.get_param('quuz'), None) + msg = email.message_from_string( + 'X-Header: foo; bar="one"; baz=two\n') + eq(msg.get_param('foo', header='x-header'), 
'') + eq(msg.get_param('bar', header='x-header'), 'one') + eq(msg.get_param('baz', header='x-header'), 'two') + # XXX: We are not RFC-2045 compliant! We cannot parse: + # msg["Content-Type"] = 'text/plain; weird="hey; dolly? [you] @ <\\"home\\">?"' + # msg.get_param("weird") + # yet. + + def test_get_param_funky_continuation_lines(self): + msg = self._msgobj('msg_22.txt') + self.assertEqual(msg.get_payload(1).get_param('name'), 'wibble.JPG') + + def test_get_param_with_semis_in_quotes(self): + msg = email.message_from_string( + 'Content-Type: image/pjpeg; name="Jim&&Jill"\n') + self.assertEqual(msg.get_param('name'), 'Jim&&Jill') + self.assertEqual(msg.get_param('name', unquote=False), + '"Jim&&Jill"') + + def test_has_key(self): + msg = email.message_from_string('Header: exists') + self.assertTrue(msg.has_key('header')) + self.assertTrue(msg.has_key('Header')) + self.assertTrue(msg.has_key('HEADER')) + self.assertFalse(msg.has_key('headeri')) + + def test_set_param(self): + eq = self.assertEqual + msg = Message() + msg.set_param('charset', 'iso-2022-jp') + eq(msg.get_param('charset'), 'iso-2022-jp') + msg.set_param('importance', 'high value') + eq(msg.get_param('importance'), 'high value') + eq(msg.get_param('importance', unquote=False), '"high value"') + eq(msg.get_params(), [('text/plain', ''), + ('charset', 'iso-2022-jp'), + ('importance', 'high value')]) + eq(msg.get_params(unquote=False), [('text/plain', ''), + ('charset', '"iso-2022-jp"'), + ('importance', '"high value"')]) + msg.set_param('charset', 'iso-9999-xx', header='X-Jimmy') + eq(msg.get_param('charset', header='X-Jimmy'), 'iso-9999-xx') + + def test_del_param(self): + eq = self.assertEqual + msg = self._msgobj('msg_05.txt') + eq(msg.get_params(), + [('multipart/report', ''), ('report-type', 'delivery-status'), + ('boundary', 'D1690A7AC1.996856090/mail.example.com')]) + old_val = msg.get_param("report-type") + msg.del_param("report-type") + eq(msg.get_params(), + [('multipart/report', ''), + 
('boundary', 'D1690A7AC1.996856090/mail.example.com')]) + msg.set_param("report-type", old_val) + eq(msg.get_params(), + [('multipart/report', ''), + ('boundary', 'D1690A7AC1.996856090/mail.example.com'), + ('report-type', old_val)]) + + def test_del_param_on_other_header(self): + msg = Message() + msg.add_header('Content-Disposition', 'attachment', filename='bud.gif') + msg.del_param('filename', 'content-disposition') + self.assertEqual(msg['content-disposition'], 'attachment') + + def test_set_type(self): + eq = self.assertEqual + msg = Message() + self.assertRaises(ValueError, msg.set_type, 'text') + msg.set_type('text/plain') + eq(msg['content-type'], 'text/plain') + msg.set_param('charset', 'us-ascii') + eq(msg['content-type'], 'text/plain; charset="us-ascii"') + msg.set_type('text/html') + eq(msg['content-type'], 'text/html; charset="us-ascii"') + + def test_set_type_on_other_header(self): + msg = Message() + msg['X-Content-Type'] = 'text/plain' + msg.set_type('application/octet-stream', 'X-Content-Type') + self.assertEqual(msg['x-content-type'], 'application/octet-stream') + + def test_get_content_type_missing(self): + msg = Message() + self.assertEqual(msg.get_content_type(), 'text/plain') + + def test_get_content_type_missing_with_default_type(self): + msg = Message() + msg.set_default_type('message/rfc822') + self.assertEqual(msg.get_content_type(), 'message/rfc822') + + def test_get_content_type_from_message_implicit(self): + msg = self._msgobj('msg_30.txt') + self.assertEqual(msg.get_payload(0).get_content_type(), + 'message/rfc822') + + def test_get_content_type_from_message_explicit(self): + msg = self._msgobj('msg_28.txt') + self.assertEqual(msg.get_payload(0).get_content_type(), + 'message/rfc822') + + def test_get_content_type_from_message_text_plain_implicit(self): + msg = self._msgobj('msg_03.txt') + self.assertEqual(msg.get_content_type(), 'text/plain') + + def test_get_content_type_from_message_text_plain_explicit(self): + msg = 
self._msgobj('msg_01.txt') + self.assertEqual(msg.get_content_type(), 'text/plain') + + def test_get_content_maintype_missing(self): + msg = Message() + self.assertEqual(msg.get_content_maintype(), 'text') + + def test_get_content_maintype_missing_with_default_type(self): + msg = Message() + msg.set_default_type('message/rfc822') + self.assertEqual(msg.get_content_maintype(), 'message') + + def test_get_content_maintype_from_message_implicit(self): + msg = self._msgobj('msg_30.txt') + self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message') + + def test_get_content_maintype_from_message_explicit(self): + msg = self._msgobj('msg_28.txt') + self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message') + + def test_get_content_maintype_from_message_text_plain_implicit(self): + msg = self._msgobj('msg_03.txt') + self.assertEqual(msg.get_content_maintype(), 'text') + + def test_get_content_maintype_from_message_text_plain_explicit(self): + msg = self._msgobj('msg_01.txt') + self.assertEqual(msg.get_content_maintype(), 'text') + + def test_get_content_subtype_missing(self): + msg = Message() + self.assertEqual(msg.get_content_subtype(), 'plain') + + def test_get_content_subtype_missing_with_default_type(self): + msg = Message() + msg.set_default_type('message/rfc822') + self.assertEqual(msg.get_content_subtype(), 'rfc822') + + def test_get_content_subtype_from_message_implicit(self): + msg = self._msgobj('msg_30.txt') + self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822') + + def test_get_content_subtype_from_message_explicit(self): + msg = self._msgobj('msg_28.txt') + self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822') + + def test_get_content_subtype_from_message_text_plain_implicit(self): + msg = self._msgobj('msg_03.txt') + self.assertEqual(msg.get_content_subtype(), 'plain') + + def test_get_content_subtype_from_message_text_plain_explicit(self): + msg = self._msgobj('msg_01.txt') + 
self.assertEqual(msg.get_content_subtype(), 'plain') + + def test_get_content_maintype_error(self): + msg = Message() + msg['Content-Type'] = 'no-slash-in-this-string' + self.assertEqual(msg.get_content_maintype(), 'text') + + def test_get_content_subtype_error(self): + msg = Message() + msg['Content-Type'] = 'no-slash-in-this-string' + self.assertEqual(msg.get_content_subtype(), 'plain') + + def test_replace_header(self): + eq = self.assertEqual + msg = Message() + msg.add_header('First', 'One') + msg.add_header('Second', 'Two') + msg.add_header('Third', 'Three') + eq(msg.keys(), ['First', 'Second', 'Third']) + eq(msg.values(), ['One', 'Two', 'Three']) + msg.replace_header('Second', 'Twenty') + eq(msg.keys(), ['First', 'Second', 'Third']) + eq(msg.values(), ['One', 'Twenty', 'Three']) + msg.add_header('First', 'Eleven') + msg.replace_header('First', 'One Hundred') + eq(msg.keys(), ['First', 'Second', 'Third', 'First']) + eq(msg.values(), ['One Hundred', 'Twenty', 'Three', 'Eleven']) + self.assertRaises(KeyError, msg.replace_header, 'Fourth', 'Missing') + + def test_broken_base64_payload(self): + x = 'AwDp0P7//y6LwKEAcPa/6Q=9' + msg = Message() + msg['content-type'] = 'audio/x-midi' + msg['content-transfer-encoding'] = 'base64' + msg.set_payload(x) + self.assertEqual(msg.get_payload(decode=True), x) + + + +# Test the email.encoders module +class TestEncoders(unittest.TestCase): + def test_encode_empty_payload(self): + eq = self.assertEqual + msg = Message() + msg.set_charset('us-ascii') + eq(msg['content-transfer-encoding'], '7bit') + + def test_default_cte(self): + eq = self.assertEqual + msg = MIMEText('hello world') + eq(msg['content-transfer-encoding'], '7bit') + + def test_default_cte(self): + eq = self.assertEqual + # With no explicit _charset its us-ascii, and all are 7-bit + msg = MIMEText('hello world') + eq(msg['content-transfer-encoding'], '7bit') + # Similar, but with 8-bit data + msg = MIMEText('hello \xf8 world') + eq(msg['content-transfer-encoding'], 
'8bit') + # And now with a different charset + msg = MIMEText('hello \xf8 world', _charset='iso-8859-1') + eq(msg['content-transfer-encoding'], 'quoted-printable') + + + +# Test long header wrapping +class TestLongHeaders(TestEmailBase): + def test_split_long_continuation(self): + eq = self.ndiffAssertEqual + msg = email.message_from_string("""\ +Subject: bug demonstration +\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789 +\tmore text + +test +""") + sfp = StringIO() + g = Generator(sfp) + g.flatten(msg) + eq(sfp.getvalue(), """\ +Subject: bug demonstration + 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789 + more text + +test +""") + + def test_another_long_almost_unsplittable_header(self): + eq = self.ndiffAssertEqual + hstr = """\ +bug demonstration +\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789 +\tmore text""" + h = Header(hstr, continuation_ws='\t') + eq(h.encode(), """\ +bug demonstration +\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789 +\tmore text""") + h = Header(hstr) + eq(h.encode(), """\ +bug demonstration + 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789 + more text""") + + def test_long_nonstring(self): + eq = self.ndiffAssertEqual + g = Charset("iso-8859-1") + cz = Charset("iso-8859-2") + utf8 = Charset("utf-8") + g_head = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. " + cz_head = "Finan\xe8ni metropole se hroutily pod tlakem jejich d\xf9vtipu.. 
" + utf8_head = u"\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das Nunstuck git und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066\u3044\u307e\u3059\u3002".encode("utf-8") + h = Header(g_head, g, header_name='Subject') + h.append(cz_head, cz) + h.append(utf8_head, utf8) + msg = Message() + msg['Subject'] = h + sfp = StringIO() + g = Generator(sfp) + g.flatten(msg) + eq(sfp.getvalue(), """\ +Subject: =?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerd?= + =?iso-8859-1?q?erband_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndi?= + =?iso-8859-1?q?schen_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Kling?= + =?iso-8859-1?q?en_bef=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_met?= + =?iso-8859-2?q?ropole_se_hroutily_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= + =?utf-8?b?5q2j56K644Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE?= + =?utf-8?b?44G+44Gb44KT44CC5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB?= + =?utf-8?b?44GC44Go44Gv44Gn44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CM?= + =?utf-8?q?Wenn_ist_das_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das?= + =?utf-8?b?IE9kZXIgZGllIEZsaXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBow==?= + =?utf-8?b?44Gm44GE44G+44GZ44CC?= + +""") + eq(h.encode(), """\ +=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerd?= + =?iso-8859-1?q?erband_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndi?= + =?iso-8859-1?q?schen_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Kling?= + =?iso-8859-1?q?en_bef=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_met?= + =?iso-8859-2?q?ropole_se_hroutily_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= + =?utf-8?b?5q2j56K644Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE?= + =?utf-8?b?44G+44Gb44KT44CC5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB?= + 
=?utf-8?b?44GC44Go44Gv44Gn44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CM?= + =?utf-8?q?Wenn_ist_das_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das?= + =?utf-8?b?IE9kZXIgZGllIEZsaXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBow==?= + =?utf-8?b?44Gm44GE44G+44GZ44CC?=""") + + def test_long_header_encode(self): + eq = self.ndiffAssertEqual + h = Header('wasnipoop; giraffes="very-long-necked-animals"; ' + 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"', + header_name='X-Foobar-Spoink-Defrobnit') + eq(h.encode(), '''\ +wasnipoop; giraffes="very-long-necked-animals"; + spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''') + + def test_long_header_encode_with_tab_continuation(self): + eq = self.ndiffAssertEqual + h = Header('wasnipoop; giraffes="very-long-necked-animals"; ' + 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"', + header_name='X-Foobar-Spoink-Defrobnit', + continuation_ws='\t') + eq(h.encode(), '''\ +wasnipoop; giraffes="very-long-necked-animals"; +\tspooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''') + + def test_header_splitter(self): + eq = self.ndiffAssertEqual + msg = MIMEText('') + # It'd be great if we could use add_header() here, but that doesn't + # guarantee an order of the parameters. 
+ msg['X-Foobar-Spoink-Defrobnit'] = ( + 'wasnipoop; giraffes="very-long-necked-animals"; ' + 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"') + sfp = StringIO() + g = Generator(sfp) + g.flatten(msg) + eq(sfp.getvalue(), '''\ +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit +X-Foobar-Spoink-Defrobnit: wasnipoop; giraffes="very-long-necked-animals"; + spooge="yummy"; hippos="gargantuan"; marshmallows="gooey" + +''') + + def test_no_semis_header_splitter(self): + eq = self.ndiffAssertEqual + msg = Message() + msg['From'] = 'test at dom.ain' + msg['References'] = SPACE.join(['<%d at dom.ain>' % i for i in range(10)]) + msg.set_payload('Test') + sfp = StringIO() + g = Generator(sfp) + g.flatten(msg) + eq(sfp.getvalue(), """\ +From: test at dom.ain +References: <0 at dom.ain> <1 at dom.ain> <2 at dom.ain> <3 at dom.ain> <4 at dom.ain> + <5 at dom.ain> <6 at dom.ain> <7 at dom.ain> <8 at dom.ain> <9 at dom.ain> + +Test""") + + def test_no_split_long_header(self): + eq = self.ndiffAssertEqual + hstr = 'References: ' + 'x' * 80 + h = Header(hstr, continuation_ws='\t') + eq(h.encode(), """\ +References: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx""") + + def test_splitting_multiple_long_lines(self): + eq = self.ndiffAssertEqual + hstr = """\ +from babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for ; Sat, 2 Feb 2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for ; Sat, 2 Feb 2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for ; Sat, 2 Feb 2002 17:00:06 -0800 (PST) +""" + h = Header(hstr, continuation_ws='\t') + eq(h.encode(), """\ +from babylon.socal-raves.org (localhost [127.0.0.1]); +\tby babylon.socal-raves.org 
(Postfix) with ESMTP id B570E51B81; +\tfor ; +\tSat, 2 Feb 2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); +\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; +\tfor ; +\tSat, 2 Feb 2002 17:00:06 -0800 (PST) +\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); +\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; +\tfor ; +\tSat, 2 Feb 2002 17:00:06 -0800 (PST)""") + + def test_splitting_first_line_only_is_long(self): + eq = self.ndiffAssertEqual + hstr = """\ +from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93] helo=cthulhu.gerg.ca) +\tby kronos.mems-exchange.org with esmtp (Exim 4.05) +\tid 17k4h5-00034i-00 +\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400""" + h = Header(hstr, maxlinelen=78, header_name='Received', + continuation_ws='\t') + eq(h.encode(), """\ +from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93] +\thelo=cthulhu.gerg.ca) +\tby kronos.mems-exchange.org with esmtp (Exim 4.05) +\tid 17k4h5-00034i-00 +\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400""") + + def test_long_8bit_header(self): + eq = self.ndiffAssertEqual + msg = Message() + h = Header('Britische Regierung gibt', 'iso-8859-1', + header_name='Subject') + h.append('gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte') + msg['Subject'] = h + eq(msg.as_string(), """\ +Subject: =?iso-8859-1?q?Britische_Regierung_gibt?= =?iso-8859-1?q?gr=FCnes?= + =?iso-8859-1?q?_Licht_f=FCr_Offshore-Windkraftprojekte?= + +""") + + def test_long_8bit_header_no_charset(self): + eq = self.ndiffAssertEqual + msg = Message() + msg['Reply-To'] = 'Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte ' + eq(msg.as_string(), """\ +Reply-To: Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte + +""") + + def test_long_to_header(self): + eq = self.ndiffAssertEqual + to = '"Someone Test #A" ,,"Someone Test #B" , "Someone Test #C" , "Someone Test #D" ' + 
msg = Message() + msg['To'] = to + eq(msg.as_string(0), '''\ +To: "Someone Test #A" , , + "Someone Test #B" , + "Someone Test #C" , + "Someone Test #D" + +''') + + def test_long_line_after_append(self): + eq = self.ndiffAssertEqual + s = 'This is an example of string which has almost the limit of header length.' + h = Header(s) + h.append('Add another line.') + eq(h.encode(), """\ +This is an example of string which has almost the limit of header length. + Add another line.""") + + def test_shorter_line_with_append(self): + eq = self.ndiffAssertEqual + s = 'This is a shorter line.' + h = Header(s) + h.append('Add another sentence. (Surprise?)') + eq(h.encode(), + 'This is a shorter line. Add another sentence. (Surprise?)') + + def test_long_field_name(self): + eq = self.ndiffAssertEqual + fn = 'X-Very-Very-Very-Long-Header-Name' + gs = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. 
" + h = Header(gs, 'iso-8859-1', header_name=fn) + # BAW: this seems broken because the first line is too long + eq(h.encode(), """\ +=?iso-8859-1?q?Die_Mieter_treten_hier_?= + =?iso-8859-1?q?ein_werden_mit_einem_Foerderband_komfortabel_den_Korridor_?= + =?iso-8859-1?q?entlang=2C_an_s=FCdl=FCndischen_Wandgem=E4lden_vorbei=2C_g?= + =?iso-8859-1?q?egen_die_rotierenden_Klingen_bef=F6rdert=2E_?=""") + + def test_long_received_header(self): + h = 'from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP; Wed, 05 Mar 2003 18:10:18 -0700' + msg = Message() + msg['Received-1'] = Header(h, continuation_ws='\t') + msg['Received-2'] = h + self.ndiffAssertEqual(msg.as_string(), """\ +Received-1: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by +\throthgar.la.mastaler.com (tmda-ofmipd) with ESMTP; +\tWed, 05 Mar 2003 18:10:18 -0700 +Received-2: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by + hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP; + Wed, 05 Mar 2003 18:10:18 -0700 + +""") + + def test_string_headerinst_eq(self): + h = '<15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> (David Bremner\'s message of "Thu, 6 Mar 2003 13:58:21 +0100")' + msg = Message() + msg['Received'] = Header(h, header_name='Received-1', + continuation_ws='\t') + msg['Received'] = h + self.ndiffAssertEqual(msg.as_string(), """\ +Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> +\t(David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100") +Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> + (David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100") + +""") + + def test_long_unbreakable_lines_with_continuation(self): + eq = self.ndiffAssertEqual + msg = Message() + t = """\ + iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9 + locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp""" + msg['Face-1'] = t + 
msg['Face-2'] = Header(t, header_name='Face-2') + eq(msg.as_string(), """\ +Face-1: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9 + locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp +Face-2: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9 + locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp + +""") + + def test_another_long_multiline_header(self): + eq = self.ndiffAssertEqual + m = '''\ +Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with Microsoft SMTPSVC(5.0.2195.4905); + Wed, 16 Oct 2002 07:41:11 -0700''' + msg = email.message_from_string(m) + eq(msg.as_string(), '''\ +Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with + Microsoft SMTPSVC(5.0.2195.4905); Wed, 16 Oct 2002 07:41:11 -0700 + +''') + + def test_long_lines_with_different_header(self): + eq = self.ndiffAssertEqual + h = """\ +List-Unsubscribe: , + """ + msg = Message() + msg['List'] = h + msg['List'] = Header(h, header_name='List') + self.ndiffAssertEqual(msg.as_string(), """\ +List: List-Unsubscribe: , + +List: List-Unsubscribe: , + + +""") + + + +# Test mangling of "From " lines in the body of a message +class TestFromMangling(unittest.TestCase): + def setUp(self): + self.msg = Message() + self.msg['From'] = 'aaa at bbb.org' + self.msg.set_payload("""\ +From the desk of A.A.A.: +Blah blah blah +""") + + def test_mangled_from(self): + s = StringIO() + g = Generator(s, mangle_from_=True) + g.flatten(self.msg) + self.assertEqual(s.getvalue(), """\ +From: aaa at bbb.org + +>From the desk of A.A.A.: +Blah blah blah +""") + + def test_dont_mangle_from(self): + s = StringIO() + g = Generator(s, mangle_from_=False) + g.flatten(self.msg) + self.assertEqual(s.getvalue(), """\ +From: aaa at bbb.org + +From the desk of A.A.A.: +Blah blah blah +""") + + + +# Test the basic MIMEAudio class +class TestMIMEAudio(unittest.TestCase): + def setUp(self): + # Make sure we pick up the 
audiotest.au that lives in email/test/data. + # In Python, there's an audiotest.au living in Lib/test but that isn't + # included in some binary distros that don't include the test + # package. The trailing empty string on the .join() is significant + # since findfile() will do a dirname(). + datadir = os.path.join(os.path.dirname(landmark), 'data', '') + fp = open(findfile('audiotest.au', datadir), 'rb') + try: + self._audiodata = fp.read() + finally: + fp.close() + self._au = MIMEAudio(self._audiodata) + + def test_guess_minor_type(self): + self.assertEqual(self._au.get_content_type(), 'audio/basic') + + def test_encoding(self): + payload = self._au.get_payload() + self.assertEqual(base64.decodestring(payload), self._audiodata) + + def test_checkSetMinor(self): + au = MIMEAudio(self._audiodata, 'fish') + self.assertEqual(au.get_content_type(), 'audio/fish') + + def test_add_header(self): + eq = self.assertEqual + unless = self.assertTrue + self._au.add_header('Content-Disposition', 'attachment', + filename='audiotest.au') + eq(self._au['content-disposition'], + 'attachment; filename="audiotest.au"') + eq(self._au.get_params(header='content-disposition'), + [('attachment', ''), ('filename', 'audiotest.au')]) + eq(self._au.get_param('filename', header='content-disposition'), + 'audiotest.au') + missing = [] + eq(self._au.get_param('attachment', header='content-disposition'), '') + unless(self._au.get_param('foo', failobj=missing, + header='content-disposition') is missing) + # Try some missing stuff + unless(self._au.get_param('foobar', missing) is missing) + unless(self._au.get_param('attachment', missing, + header='foobar') is missing) + + + +# Test the basic MIMEImage class +class TestMIMEImage(unittest.TestCase): + def setUp(self): + fp = openfile('PyBanner048.gif') + try: + self._imgdata = fp.read() + finally: + fp.close() + self._im = MIMEImage(self._imgdata) + + def test_guess_minor_type(self): + self.assertEqual(self._im.get_content_type(), 'image/gif') + + 
def test_encoding(self): + payload = self._im.get_payload() + self.assertEqual(base64.decodestring(payload), self._imgdata) + + def test_checkSetMinor(self): + im = MIMEImage(self._imgdata, 'fish') + self.assertEqual(im.get_content_type(), 'image/fish') + + def test_add_header(self): + eq = self.assertEqual + unless = self.assertTrue + self._im.add_header('Content-Disposition', 'attachment', + filename='dingusfish.gif') + eq(self._im['content-disposition'], + 'attachment; filename="dingusfish.gif"') + eq(self._im.get_params(header='content-disposition'), + [('attachment', ''), ('filename', 'dingusfish.gif')]) + eq(self._im.get_param('filename', header='content-disposition'), + 'dingusfish.gif') + missing = [] + eq(self._im.get_param('attachment', header='content-disposition'), '') + unless(self._im.get_param('foo', failobj=missing, + header='content-disposition') is missing) + # Try some missing stuff + unless(self._im.get_param('foobar', missing) is missing) + unless(self._im.get_param('attachment', missing, + header='foobar') is missing) + + + +# Test the basic MIMEApplication class +class TestMIMEApplication(unittest.TestCase): + def test_headers(self): + eq = self.assertEqual + msg = MIMEApplication('\xfa\xfb\xfc\xfd\xfe\xff') + eq(msg.get_content_type(), 'application/octet-stream') + eq(msg['content-transfer-encoding'], 'base64') + + def test_body(self): + eq = self.assertEqual + bytes = '\xfa\xfb\xfc\xfd\xfe\xff' + msg = MIMEApplication(bytes) + eq(msg.get_payload(), '+vv8/f7/') + eq(msg.get_payload(decode=True), bytes) + + def test_binary_body_with_encode_7or8bit(self): + # Issue 17171. + bytesdata = b'\xfa\xfb\xfc\xfd\xfe\xff' + msg = MIMEApplication(bytesdata, _encoder=encoders.encode_7or8bit) + # Treated as a string, this will be invalid code points. 
+ self.assertEqual(msg.get_payload(), bytesdata) + self.assertEqual(msg.get_payload(decode=True), bytesdata) + self.assertEqual(msg['Content-Transfer-Encoding'], '8bit') + s = StringIO() + g = Generator(s) + g.flatten(msg) + wireform = s.getvalue() + msg2 = email.message_from_string(wireform) + self.assertEqual(msg.get_payload(), bytesdata) + self.assertEqual(msg2.get_payload(decode=True), bytesdata) + self.assertEqual(msg2['Content-Transfer-Encoding'], '8bit') + + def test_binary_body_with_encode_noop(self): + # Issue 16564: This does not produce an RFC valid message, since to be + # valid it should have a CTE of binary. But the below works, and is + # documented as working this way. + bytesdata = b'\xfa\xfb\xfc\xfd\xfe\xff' + msg = MIMEApplication(bytesdata, _encoder=encoders.encode_noop) + self.assertEqual(msg.get_payload(), bytesdata) + self.assertEqual(msg.get_payload(decode=True), bytesdata) + s = StringIO() + g = Generator(s) + g.flatten(msg) + wireform = s.getvalue() + msg2 = email.message_from_string(wireform) + self.assertEqual(msg.get_payload(), bytesdata) + self.assertEqual(msg2.get_payload(decode=True), bytesdata) + + +# Test the basic MIMEText class +class TestMIMEText(unittest.TestCase): + def setUp(self): + self._msg = MIMEText('hello there') + + def test_types(self): + eq = self.assertEqual + unless = self.assertTrue + eq(self._msg.get_content_type(), 'text/plain') + eq(self._msg.get_param('charset'), 'us-ascii') + missing = [] + unless(self._msg.get_param('foobar', missing) is missing) + unless(self._msg.get_param('charset', missing, header='foobar') + is missing) + + def test_payload(self): + self.assertEqual(self._msg.get_payload(), 'hello there') + self.assertTrue(not self._msg.is_multipart()) + + def test_charset(self): + eq = self.assertEqual + msg = MIMEText('hello there', _charset='us-ascii') + eq(msg.get_charset().input_charset, 'us-ascii') + eq(msg['content-type'], 'text/plain; charset="us-ascii"') + + + +# Test complicated multipart/* 
messages +class TestMultipart(TestEmailBase): + def setUp(self): + fp = openfile('PyBanner048.gif') + try: + data = fp.read() + finally: + fp.close() + + container = MIMEBase('multipart', 'mixed', boundary='BOUNDARY') + image = MIMEImage(data, name='dingusfish.gif') + image.add_header('content-disposition', 'attachment', + filename='dingusfish.gif') + intro = MIMEText('''\ +Hi there, + +This is the dingus fish. +''') + container.attach(intro) + container.attach(image) + container['From'] = 'Barry ' + container['To'] = 'Dingus Lovers ' + container['Subject'] = 'Here is your dingus fish' + + now = 987809702.54848599 + timetuple = time.localtime(now) + if timetuple[-1] == 0: + tzsecs = time.timezone + else: + tzsecs = time.altzone + if tzsecs > 0: + sign = '-' + else: + sign = '+' + tzoffset = ' %s%04d' % (sign, tzsecs // 36) + container['Date'] = time.strftime( + '%a, %d %b %Y %H:%M:%S', + time.localtime(now)) + tzoffset + self._msg = container + self._im = image + self._txt = intro + + def test_hierarchy(self): + # convenience + eq = self.assertEqual + unless = self.assertTrue + raises = self.assertRaises + # tests + m = self._msg + unless(m.is_multipart()) + eq(m.get_content_type(), 'multipart/mixed') + eq(len(m.get_payload()), 2) + raises(IndexError, m.get_payload, 2) + m0 = m.get_payload(0) + m1 = m.get_payload(1) + unless(m0 is self._txt) + unless(m1 is self._im) + eq(m.get_payload(), [m0, m1]) + unless(not m0.is_multipart()) + unless(not m1.is_multipart()) + + def test_empty_multipart_idempotent(self): + text = """\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + + +--BOUNDARY + + +--BOUNDARY-- +""" + msg = Parser().parsestr(text) + self.ndiffAssertEqual(text, msg.as_string()) + + def test_no_parts_in_a_multipart_with_none_epilogue(self): + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 
'bperson at dom.ain' + outer.set_boundary('BOUNDARY') + self.ndiffAssertEqual(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY + +--BOUNDARY--''') + + def test_no_parts_in_a_multipart_with_empty_epilogue(self): + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.preamble = '' + outer.epilogue = '' + outer.set_boundary('BOUNDARY') + self.ndiffAssertEqual(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + + +--BOUNDARY + +--BOUNDARY-- +''') + + def test_one_part_in_a_multipart(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.set_boundary('BOUNDARY') + msg = MIMEText('hello world') + outer.attach(msg) + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + def test_seq_parts_in_a_multipart_with_empty_preamble(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.preamble = '' + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + + +--BOUNDARY +Content-Type: text/plain; 
charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + + def test_seq_parts_in_a_multipart_with_none_preamble(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.preamble = None + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + + def test_seq_parts_in_a_multipart_with_none_epilogue(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.epilogue = None + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY--''') + + + def test_seq_parts_in_a_multipart_with_empty_epilogue(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.epilogue = '' + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: 
text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY-- +''') + + + def test_seq_parts_in_a_multipart_with_nl_epilogue(self): + eq = self.ndiffAssertEqual + outer = MIMEBase('multipart', 'mixed') + outer['Subject'] = 'A subject' + outer['To'] = 'aperson at dom.ain' + outer['From'] = 'bperson at dom.ain' + outer.epilogue = '\n' + msg = MIMEText('hello world') + outer.attach(msg) + outer.set_boundary('BOUNDARY') + eq(outer.as_string(), '''\ +Content-Type: multipart/mixed; boundary="BOUNDARY" +MIME-Version: 1.0 +Subject: A subject +To: aperson at dom.ain +From: bperson at dom.ain + +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +hello world +--BOUNDARY-- + +''') + + def test_message_external_body(self): + eq = self.assertEqual + msg = self._msgobj('msg_36.txt') + eq(len(msg.get_payload()), 2) + msg1 = msg.get_payload(1) + eq(msg1.get_content_type(), 'multipart/alternative') + eq(len(msg1.get_payload()), 2) + for subpart in msg1.get_payload(): + eq(subpart.get_content_type(), 'message/external-body') + eq(len(subpart.get_payload()), 1) + subsubpart = subpart.get_payload(0) + eq(subsubpart.get_content_type(), 'text/plain') + + def test_double_boundary(self): + # msg_37.txt is a multipart that contains two dash-boundary's in a + # row. Our interpretation of RFC 2046 calls for ignoring the second + # and subsequent boundaries. + msg = self._msgobj('msg_37.txt') + self.assertEqual(len(msg.get_payload()), 3) + + def test_nested_inner_contains_outer_boundary(self): + eq = self.ndiffAssertEqual + # msg_38.txt has an inner part that contains outer boundaries. My + # interpretation of RFC 2046 (based on sections 5.1 and 5.1.2) say + # these are illegal and should be interpreted as unterminated inner + # parts. 
+ msg = self._msgobj('msg_38.txt') + sfp = StringIO() + iterators._structure(msg, sfp) + eq(sfp.getvalue(), """\ +multipart/mixed + multipart/mixed + multipart/alternative + text/plain + text/plain + text/plain + text/plain +""") + + def test_nested_with_same_boundary(self): + eq = self.ndiffAssertEqual + # msg 39.txt is similarly evil in that it's got inner parts that use + # the same boundary as outer parts. Again, I believe the way this is + # parsed is closest to the spirit of RFC 2046 + msg = self._msgobj('msg_39.txt') + sfp = StringIO() + iterators._structure(msg, sfp) + eq(sfp.getvalue(), """\ +multipart/mixed + multipart/mixed + multipart/alternative + application/octet-stream + application/octet-stream + text/plain +""") + + def test_boundary_in_non_multipart(self): + msg = self._msgobj('msg_40.txt') + self.assertEqual(msg.as_string(), '''\ +MIME-Version: 1.0 +Content-Type: text/html; boundary="--961284236552522269" + +----961284236552522269 +Content-Type: text/html; +Content-Transfer-Encoding: 7Bit + + + +----961284236552522269-- +''') + + def test_boundary_with_leading_space(self): + eq = self.assertEqual + msg = email.message_from_string('''\ +MIME-Version: 1.0 +Content-Type: multipart/mixed; boundary=" XXXX" + +-- XXXX +Content-Type: text/plain + + +-- XXXX +Content-Type: text/plain + +-- XXXX-- +''') + self.assertTrue(msg.is_multipart()) + eq(msg.get_boundary(), ' XXXX') + eq(len(msg.get_payload()), 2) + + def test_boundary_without_trailing_newline(self): + m = Parser().parsestr("""\ +Content-Type: multipart/mixed; boundary="===============0012394164==" +MIME-Version: 1.0 + +--===============0012394164== +Content-Type: image/file1.jpg +MIME-Version: 1.0 +Content-Transfer-Encoding: base64 + +YXNkZg== +--===============0012394164==--""") + self.assertEqual(m.get_payload(0).get_payload(), 'YXNkZg==') + + + +# Test some badly formatted messages +class TestNonConformant(TestEmailBase): + def test_parse_missing_minor_type(self): + eq = self.assertEqual + 
msg = self._msgobj('msg_14.txt') + eq(msg.get_content_type(), 'text/plain') + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + + def test_same_boundary_inner_outer(self): + unless = self.assertTrue + msg = self._msgobj('msg_15.txt') + # XXX We can probably eventually do better + inner = msg.get_payload(0) + unless(hasattr(inner, 'defects')) + self.assertEqual(len(inner.defects), 1) + unless(isinstance(inner.defects[0], + errors.StartBoundaryNotFoundDefect)) + + def test_multipart_no_boundary(self): + unless = self.assertTrue + msg = self._msgobj('msg_25.txt') + unless(isinstance(msg.get_payload(), str)) + self.assertEqual(len(msg.defects), 2) + unless(isinstance(msg.defects[0], errors.NoBoundaryInMultipartDefect)) + unless(isinstance(msg.defects[1], + errors.MultipartInvariantViolationDefect)) + + def test_invalid_content_type(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + msg = Message() + # RFC 2045, $5.2 says invalid yields text/plain + msg['Content-Type'] = 'text' + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + eq(msg.get_content_type(), 'text/plain') + # Clear the old value and try something /really/ invalid + del msg['content-type'] + msg['Content-Type'] = 'foo' + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + eq(msg.get_content_type(), 'text/plain') + # Still, make sure that the message is idempotently generated + s = StringIO() + g = Generator(s) + g.flatten(msg) + neq(s.getvalue(), 'Content-Type: foo\n\n') + + def test_no_start_boundary(self): + eq = self.ndiffAssertEqual + msg = self._msgobj('msg_31.txt') + eq(msg.get_payload(), """\ +--BOUNDARY +Content-Type: text/plain + +message 1 + +--BOUNDARY +Content-Type: text/plain + +message 2 + +--BOUNDARY-- +""") + + def test_no_separating_blank_line(self): + eq = self.ndiffAssertEqual + msg = self._msgobj('msg_35.txt') + eq(msg.as_string(), """\ +From: aperson at dom.ain +To: bperson at 
dom.ain +Subject: here's something interesting + +counter to RFC 2822, there's no separating newline here +""") + + def test_lying_multipart(self): + unless = self.assertTrue + msg = self._msgobj('msg_41.txt') + unless(hasattr(msg, 'defects')) + self.assertEqual(len(msg.defects), 2) + unless(isinstance(msg.defects[0], errors.NoBoundaryInMultipartDefect)) + unless(isinstance(msg.defects[1], + errors.MultipartInvariantViolationDefect)) + + def test_missing_start_boundary(self): + outer = self._msgobj('msg_42.txt') + # The message structure is: + # + # multipart/mixed + # text/plain + # message/rfc822 + # multipart/mixed [*] + # + # [*] This message is missing its start boundary + bad = outer.get_payload(1).get_payload(0) + self.assertEqual(len(bad.defects), 1) + self.assertTrue(isinstance(bad.defects[0], + errors.StartBoundaryNotFoundDefect)) + + def test_first_line_is_continuation_header(self): + eq = self.assertEqual + m = ' Line 1\nLine 2\nLine 3' + msg = email.message_from_string(m) + eq(msg.keys(), []) + eq(msg.get_payload(), 'Line 2\nLine 3') + eq(len(msg.defects), 1) + self.assertTrue(isinstance(msg.defects[0], + errors.FirstHeaderLineIsContinuationDefect)) + eq(msg.defects[0].line, ' Line 1\n') + + + +# Test RFC 2047 header encoding and decoding +class TestRFC2047(unittest.TestCase): + def test_rfc2047_multiline(self): + eq = self.assertEqual + s = """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz + foo bar =?mac-iceland?q?r=8Aksm=9Arg=8Cs?=""" + dh = decode_header(s) + eq(dh, [ + ('Re:', None), + ('r\x8aksm\x9arg\x8cs', 'mac-iceland'), + ('baz foo bar', None), + ('r\x8aksm\x9arg\x8cs', 'mac-iceland')]) + eq(str(make_header(dh)), + """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz foo bar + =?mac-iceland?q?r=8Aksm=9Arg=8Cs?=""") + + def test_whitespace_eater_unicode(self): + eq = self.assertEqual + s = '=?ISO-8859-1?Q?Andr=E9?= Pirard ' + dh = decode_header(s) + eq(dh, [('Andr\xe9', 'iso-8859-1'), ('Pirard ', None)]) + hu = unicode(make_header(dh)).encode('latin-1') 
+ eq(hu, 'Andr\xe9 Pirard ') + + def test_whitespace_eater_unicode_2(self): + eq = self.assertEqual + s = 'The =?iso-8859-1?b?cXVpY2sgYnJvd24gZm94?= jumped over the =?iso-8859-1?b?bGF6eSBkb2c=?=' + dh = decode_header(s) + eq(dh, [('The', None), ('quick brown fox', 'iso-8859-1'), + ('jumped over the', None), ('lazy dog', 'iso-8859-1')]) + hu = make_header(dh).__unicode__() + eq(hu, u'The quick brown fox jumped over the lazy dog') + + def test_rfc2047_missing_whitespace(self): + s = 'Sm=?ISO-8859-1?B?9g==?=rg=?ISO-8859-1?B?5Q==?=sbord' + dh = decode_header(s) + self.assertEqual(dh, [(s, None)]) + + def test_rfc2047_with_whitespace(self): + s = 'Sm =?ISO-8859-1?B?9g==?= rg =?ISO-8859-1?B?5Q==?= sbord' + dh = decode_header(s) + self.assertEqual(dh, [('Sm', None), ('\xf6', 'iso-8859-1'), + ('rg', None), ('\xe5', 'iso-8859-1'), + ('sbord', None)]) + + + +# Test the MIMEMessage class +class TestMIMEMessage(TestEmailBase): + def setUp(self): + fp = openfile('msg_11.txt') + try: + self._text = fp.read() + finally: + fp.close() + + def test_type_error(self): + self.assertRaises(TypeError, MIMEMessage, 'a plain string') + + def test_valid_argument(self): + eq = self.assertEqual + unless = self.assertTrue + subject = 'A sub-message' + m = Message() + m['Subject'] = subject + r = MIMEMessage(m) + eq(r.get_content_type(), 'message/rfc822') + payload = r.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + subpart = payload[0] + unless(subpart is m) + eq(subpart['subject'], subject) + + def test_bad_multipart(self): + eq = self.assertEqual + msg1 = Message() + msg1['Subject'] = 'subpart 1' + msg2 = Message() + msg2['Subject'] = 'subpart 2' + r = MIMEMessage(msg1) + self.assertRaises(errors.MultipartConversionError, r.attach, msg2) + + def test_generate(self): + # First craft the message to be encapsulated + m = Message() + m['Subject'] = 'An enclosed message' + m.set_payload('Here is the body of the message.\n') + r = MIMEMessage(m) + r['Subject'] = 'The 
enclosing message' + s = StringIO() + g = Generator(s) + g.flatten(r) + self.assertEqual(s.getvalue(), """\ +Content-Type: message/rfc822 +MIME-Version: 1.0 +Subject: The enclosing message + +Subject: An enclosed message + +Here is the body of the message. +""") + + def test_parse_message_rfc822(self): + eq = self.assertEqual + unless = self.assertTrue + msg = self._msgobj('msg_11.txt') + eq(msg.get_content_type(), 'message/rfc822') + payload = msg.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + submsg = payload[0] + self.assertTrue(isinstance(submsg, Message)) + eq(submsg['subject'], 'An enclosed message') + eq(submsg.get_payload(), 'Here is the body of the message.\n') + + def test_dsn(self): + eq = self.assertEqual + unless = self.assertTrue + # msg 16 is a Delivery Status Notification, see RFC 1894 + msg = self._msgobj('msg_16.txt') + eq(msg.get_content_type(), 'multipart/report') + unless(msg.is_multipart()) + eq(len(msg.get_payload()), 3) + # Subpart 1 is a text/plain, human readable section + subpart = msg.get_payload(0) + eq(subpart.get_content_type(), 'text/plain') + eq(subpart.get_payload(), """\ +This report relates to a message you sent with the following header fields: + + Message-id: <002001c144a6$8752e060$56104586 at oxy.edu> + Date: Sun, 23 Sep 2001 20:10:55 -0700 + From: "Ian T. Henry" + To: SoCal Raves + Subject: [scr] yeah for Ians!! + +Your message cannot be delivered to the following recipients: + + Recipient address: jangel1 at cougar.noc.ucla.edu + Reason: recipient reached disk quota + +""") + # Subpart 2 contains the machine parsable DSN information. It + # consists of two blocks of headers, represented by two nested Message + # objects. + subpart = msg.get_payload(1) + eq(subpart.get_content_type(), 'message/delivery-status') + eq(len(subpart.get_payload()), 2) + # message/delivery-status should treat each block as a bunch of + # headers, i.e. a bunch of Message objects. 
+ dsn1 = subpart.get_payload(0) + unless(isinstance(dsn1, Message)) + eq(dsn1['original-envelope-id'], '0GK500B4HD0888 at cougar.noc.ucla.edu') + eq(dsn1.get_param('dns', header='reporting-mta'), '') + # Try a missing one + eq(dsn1.get_param('nsd', header='reporting-mta'), None) + dsn2 = subpart.get_payload(1) + unless(isinstance(dsn2, Message)) + eq(dsn2['action'], 'failed') + eq(dsn2.get_params(header='original-recipient'), + [('rfc822', ''), ('jangel1 at cougar.noc.ucla.edu', '')]) + eq(dsn2.get_param('rfc822', header='final-recipient'), '') + # Subpart 3 is the original message + subpart = msg.get_payload(2) + eq(subpart.get_content_type(), 'message/rfc822') + payload = subpart.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + subsubpart = payload[0] + unless(isinstance(subsubpart, Message)) + eq(subsubpart.get_content_type(), 'text/plain') + eq(subsubpart['message-id'], + '<002001c144a6$8752e060$56104586 at oxy.edu>') + + def test_epilogue(self): + eq = self.ndiffAssertEqual + fp = openfile('msg_21.txt') + try: + text = fp.read() + finally: + fp.close() + msg = Message() + msg['From'] = 'aperson at dom.ain' + msg['To'] = 'bperson at dom.ain' + msg['Subject'] = 'Test' + msg.preamble = 'MIME message' + msg.epilogue = 'End of MIME message\n' + msg1 = MIMEText('One') + msg2 = MIMEText('Two') + msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY') + msg.attach(msg1) + msg.attach(msg2) + sfp = StringIO() + g = Generator(sfp) + g.flatten(msg) + eq(sfp.getvalue(), text) + + def test_no_nl_preamble(self): + eq = self.ndiffAssertEqual + msg = Message() + msg['From'] = 'aperson at dom.ain' + msg['To'] = 'bperson at dom.ain' + msg['Subject'] = 'Test' + msg.preamble = 'MIME message' + msg.epilogue = '' + msg1 = MIMEText('One') + msg2 = MIMEText('Two') + msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY') + msg.attach(msg1) + msg.attach(msg2) + eq(msg.as_string(), """\ +From: aperson at dom.ain +To: bperson at 
dom.ain +Subject: Test +Content-Type: multipart/mixed; boundary="BOUNDARY" + +MIME message +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +One +--BOUNDARY +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +Two +--BOUNDARY-- +""") + + def test_default_type(self): + eq = self.assertEqual + fp = openfile('msg_30.txt') + try: + msg = email.message_from_file(fp) + finally: + fp.close() + container1 = msg.get_payload(0) + eq(container1.get_default_type(), 'message/rfc822') + eq(container1.get_content_type(), 'message/rfc822') + container2 = msg.get_payload(1) + eq(container2.get_default_type(), 'message/rfc822') + eq(container2.get_content_type(), 'message/rfc822') + container1a = container1.get_payload(0) + eq(container1a.get_default_type(), 'text/plain') + eq(container1a.get_content_type(), 'text/plain') + container2a = container2.get_payload(0) + eq(container2a.get_default_type(), 'text/plain') + eq(container2a.get_content_type(), 'text/plain') + + def test_default_type_with_explicit_container_type(self): + eq = self.assertEqual + fp = openfile('msg_28.txt') + try: + msg = email.message_from_file(fp) + finally: + fp.close() + container1 = msg.get_payload(0) + eq(container1.get_default_type(), 'message/rfc822') + eq(container1.get_content_type(), 'message/rfc822') + container2 = msg.get_payload(1) + eq(container2.get_default_type(), 'message/rfc822') + eq(container2.get_content_type(), 'message/rfc822') + container1a = container1.get_payload(0) + eq(container1a.get_default_type(), 'text/plain') + eq(container1a.get_content_type(), 'text/plain') + container2a = container2.get_payload(0) + eq(container2a.get_default_type(), 'text/plain') + eq(container2a.get_content_type(), 'text/plain') + + def test_default_type_non_parsed(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + # Set up container + container = MIMEMultipart('digest', 'BOUNDARY') + 
container.epilogue = '' + # Set up subparts + subpart1a = MIMEText('message 1\n') + subpart2a = MIMEText('message 2\n') + subpart1 = MIMEMessage(subpart1a) + subpart2 = MIMEMessage(subpart2a) + container.attach(subpart1) + container.attach(subpart2) + eq(subpart1.get_content_type(), 'message/rfc822') + eq(subpart1.get_default_type(), 'message/rfc822') + eq(subpart2.get_content_type(), 'message/rfc822') + eq(subpart2.get_default_type(), 'message/rfc822') + neq(container.as_string(0), '''\ +Content-Type: multipart/digest; boundary="BOUNDARY" +MIME-Version: 1.0 + +--BOUNDARY +Content-Type: message/rfc822 +MIME-Version: 1.0 + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 1 + +--BOUNDARY +Content-Type: message/rfc822 +MIME-Version: 1.0 + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 2 + +--BOUNDARY-- +''') + del subpart1['content-type'] + del subpart1['mime-version'] + del subpart2['content-type'] + del subpart2['mime-version'] + eq(subpart1.get_content_type(), 'message/rfc822') + eq(subpart1.get_default_type(), 'message/rfc822') + eq(subpart2.get_content_type(), 'message/rfc822') + eq(subpart2.get_default_type(), 'message/rfc822') + neq(container.as_string(0), '''\ +Content-Type: multipart/digest; boundary="BOUNDARY" +MIME-Version: 1.0 + +--BOUNDARY + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 1 + +--BOUNDARY + +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +message 2 + +--BOUNDARY-- +''') + + def test_mime_attachments_in_constructor(self): + eq = self.assertEqual + text1 = MIMEText('') + text2 = MIMEText('') + msg = MIMEMultipart(_subparts=(text1, text2)) + eq(len(msg.get_payload()), 2) + eq(msg.get_payload(0), text1) + eq(msg.get_payload(1), text2) + + + +# A general test of parser->model->generator idempotency. 
IOW, read a message +# in, parse it into a message object tree, then without touching the tree, +# regenerate the plain text. The original text and the transformed text +# should be identical. Note: that we ignore the Unix-From since that may +# contain a changed date. +class TestIdempotent(TestEmailBase): + def _msgobj(self, filename): + fp = openfile(filename) + try: + data = fp.read() + finally: + fp.close() + msg = email.message_from_string(data) + return msg, data + + def _idempotent(self, msg, text): + eq = self.ndiffAssertEqual + s = StringIO() + g = Generator(s, maxheaderlen=0) + g.flatten(msg) + eq(text, s.getvalue()) + + def test_parse_text_message(self): + eq = self.assertEqual + msg, text = self._msgobj('msg_01.txt') + eq(msg.get_content_type(), 'text/plain') + eq(msg.get_content_maintype(), 'text') + eq(msg.get_content_subtype(), 'plain') + eq(msg.get_params()[1], ('charset', 'us-ascii')) + eq(msg.get_param('charset'), 'us-ascii') + eq(msg.preamble, None) + eq(msg.epilogue, None) + self._idempotent(msg, text) + + def test_parse_untyped_message(self): + eq = self.assertEqual + msg, text = self._msgobj('msg_03.txt') + eq(msg.get_content_type(), 'text/plain') + eq(msg.get_params(), None) + eq(msg.get_param('charset'), None) + self._idempotent(msg, text) + + def test_simple_multipart(self): + msg, text = self._msgobj('msg_04.txt') + self._idempotent(msg, text) + + def test_MIME_digest(self): + msg, text = self._msgobj('msg_02.txt') + self._idempotent(msg, text) + + def test_long_header(self): + msg, text = self._msgobj('msg_27.txt') + self._idempotent(msg, text) + + def test_MIME_digest_with_part_headers(self): + msg, text = self._msgobj('msg_28.txt') + self._idempotent(msg, text) + + def test_mixed_with_image(self): + msg, text = self._msgobj('msg_06.txt') + self._idempotent(msg, text) + + def test_multipart_report(self): + msg, text = self._msgobj('msg_05.txt') + self._idempotent(msg, text) + + def test_dsn(self): + msg, text = self._msgobj('msg_16.txt') 
+ self._idempotent(msg, text) + + def test_preamble_epilogue(self): + msg, text = self._msgobj('msg_21.txt') + self._idempotent(msg, text) + + def test_multipart_one_part(self): + msg, text = self._msgobj('msg_23.txt') + self._idempotent(msg, text) + + def test_multipart_no_parts(self): + msg, text = self._msgobj('msg_24.txt') + self._idempotent(msg, text) + + def test_no_start_boundary(self): + msg, text = self._msgobj('msg_31.txt') + self._idempotent(msg, text) + + def test_rfc2231_charset(self): + msg, text = self._msgobj('msg_32.txt') + self._idempotent(msg, text) + + def test_more_rfc2231_parameters(self): + msg, text = self._msgobj('msg_33.txt') + self._idempotent(msg, text) + + def test_text_plain_in_a_multipart_digest(self): + msg, text = self._msgobj('msg_34.txt') + self._idempotent(msg, text) + + def test_nested_multipart_mixeds(self): + msg, text = self._msgobj('msg_12a.txt') + self._idempotent(msg, text) + + def test_message_external_body_idempotent(self): + msg, text = self._msgobj('msg_36.txt') + self._idempotent(msg, text) + + def test_content_type(self): + eq = self.assertEqual + unless = self.assertTrue + # Get a message object and reset the seek pointer for other tests + msg, text = self._msgobj('msg_05.txt') + eq(msg.get_content_type(), 'multipart/report') + # Test the Content-Type: parameters + params = {} + for pk, pv in msg.get_params(): + params[pk] = pv + eq(params['report-type'], 'delivery-status') + eq(params['boundary'], 'D1690A7AC1.996856090/mail.example.com') + eq(msg.preamble, 'This is a MIME-encapsulated message.\n') + eq(msg.epilogue, '\n') + eq(len(msg.get_payload()), 3) + # Make sure the subparts are what we expect + msg1 = msg.get_payload(0) + eq(msg1.get_content_type(), 'text/plain') + eq(msg1.get_payload(), 'Yadda yadda yadda\n') + msg2 = msg.get_payload(1) + eq(msg2.get_content_type(), 'text/plain') + eq(msg2.get_payload(), 'Yadda yadda yadda\n') + msg3 = msg.get_payload(2) + eq(msg3.get_content_type(), 'message/rfc822') + 
self.assertTrue(isinstance(msg3, Message)) + payload = msg3.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + msg4 = payload[0] + unless(isinstance(msg4, Message)) + eq(msg4.get_payload(), 'Yadda yadda yadda\n') + + def test_parser(self): + eq = self.assertEqual + unless = self.assertTrue + msg, text = self._msgobj('msg_06.txt') + # Check some of the outer headers + eq(msg.get_content_type(), 'message/rfc822') + # Make sure the payload is a list of exactly one sub-Message, and that + # that submessage has a type of text/plain + payload = msg.get_payload() + unless(isinstance(payload, list)) + eq(len(payload), 1) + msg1 = payload[0] + self.assertTrue(isinstance(msg1, Message)) + eq(msg1.get_content_type(), 'text/plain') + self.assertTrue(isinstance(msg1.get_payload(), str)) + eq(msg1.get_payload(), '\n') + + + +# Test various other bits of the package's functionality +class TestMiscellaneous(TestEmailBase): + def test_message_from_string(self): + fp = openfile('msg_01.txt') + try: + text = fp.read() + finally: + fp.close() + msg = email.message_from_string(text) + s = StringIO() + # Don't wrap/continue long headers since we're trying to test + # idempotency. + g = Generator(s, maxheaderlen=0) + g.flatten(msg) + self.assertEqual(text, s.getvalue()) + + def test_message_from_file(self): + fp = openfile('msg_01.txt') + try: + text = fp.read() + fp.seek(0) + msg = email.message_from_file(fp) + s = StringIO() + # Don't wrap/continue long headers since we're trying to test + # idempotency. 
+ g = Generator(s, maxheaderlen=0) + g.flatten(msg) + self.assertEqual(text, s.getvalue()) + finally: + fp.close() + + def test_message_from_string_with_class(self): + unless = self.assertTrue + fp = openfile('msg_01.txt') + try: + text = fp.read() + finally: + fp.close() + # Create a subclass + class MyMessage(Message): + pass + + msg = email.message_from_string(text, MyMessage) + unless(isinstance(msg, MyMessage)) + # Try something more complicated + fp = openfile('msg_02.txt') + try: + text = fp.read() + finally: + fp.close() + msg = email.message_from_string(text, MyMessage) + for subpart in msg.walk(): + unless(isinstance(subpart, MyMessage)) + + def test_message_from_file_with_class(self): + unless = self.assertTrue + # Create a subclass + class MyMessage(Message): + pass + + fp = openfile('msg_01.txt') + try: + msg = email.message_from_file(fp, MyMessage) + finally: + fp.close() + unless(isinstance(msg, MyMessage)) + # Try something more complicated + fp = openfile('msg_02.txt') + try: + msg = email.message_from_file(fp, MyMessage) + finally: + fp.close() + for subpart in msg.walk(): + unless(isinstance(subpart, MyMessage)) + + def test__all__(self): + module = __import__('email') + # Can't use sorted() here due to Python 2.3 compatibility + all = module.__all__[:] + all.sort() + self.assertEqual(all, [ + # Old names + 'Charset', 'Encoders', 'Errors', 'Generator', + 'Header', 'Iterators', 'MIMEAudio', 'MIMEBase', + 'MIMEImage', 'MIMEMessage', 'MIMEMultipart', + 'MIMENonMultipart', 'MIMEText', 'Message', + 'Parser', 'Utils', 'base64MIME', + # new names + 'base64mime', 'charset', 'encoders', 'errors', 'generator', + 'header', 'iterators', 'message', 'message_from_file', + 'message_from_string', 'mime', 'parser', + 'quopriMIME', 'quoprimime', 'utils', + ]) + + def test_formatdate(self): + now = time.time() + self.assertEqual(utils.parsedate(utils.formatdate(now))[:6], + time.gmtime(now)[:6]) + + def test_formatdate_localtime(self): + now = time.time() + 
self.assertEqual( + utils.parsedate(utils.formatdate(now, localtime=True))[:6], + time.localtime(now)[:6]) + + def test_formatdate_usegmt(self): + now = time.time() + self.assertEqual( + utils.formatdate(now, localtime=False), + time.strftime('%a, %d %b %Y %H:%M:%S -0000', time.gmtime(now))) + self.assertEqual( + utils.formatdate(now, localtime=False, usegmt=True), + time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(now))) + + def test_parsedate_none(self): + self.assertEqual(utils.parsedate(''), None) + + def test_parsedate_compact(self): + # The FWS after the comma is optional + self.assertEqual(utils.parsedate('Wed,3 Apr 2002 14:58:26 +0800'), + utils.parsedate('Wed, 3 Apr 2002 14:58:26 +0800')) + + def test_parsedate_no_dayofweek(self): + eq = self.assertEqual + eq(utils.parsedate_tz('25 Feb 2003 13:47:26 -0800'), + (2003, 2, 25, 13, 47, 26, 0, 1, -1, -28800)) + + def test_parsedate_compact_no_dayofweek(self): + eq = self.assertEqual + eq(utils.parsedate_tz('5 Feb 2003 13:47:26 -0800'), + (2003, 2, 5, 13, 47, 26, 0, 1, -1, -28800)) + + def test_parsedate_acceptable_to_time_functions(self): + eq = self.assertEqual + timetup = utils.parsedate('5 Feb 2003 13:47:26 -0800') + t = int(time.mktime(timetup)) + eq(time.localtime(t)[:6], timetup[:6]) + eq(int(time.strftime('%Y', timetup)), 2003) + timetup = utils.parsedate_tz('5 Feb 2003 13:47:26 -0800') + t = int(time.mktime(timetup[:9])) + eq(time.localtime(t)[:6], timetup[:6]) + eq(int(time.strftime('%Y', timetup[:9])), 2003) + + def test_parseaddr_empty(self): + self.assertEqual(utils.parseaddr('<>'), ('', '')) + self.assertEqual(utils.formataddr(utils.parseaddr('<>')), '') + + def test_noquote_dump(self): + self.assertEqual( + utils.formataddr(('A Silly Person', 'person at dom.ain')), + 'A Silly Person ') + + def test_escape_dump(self): + self.assertEqual( + utils.formataddr(('A (Very) Silly Person', 'person at dom.ain')), + r'"A \(Very\) Silly Person" ') + a = r'A \(Special\) Person' + b = 'person at dom.ain' 
+ self.assertEqual(utils.parseaddr(utils.formataddr((a, b))), (a, b)) + + def test_escape_backslashes(self): + self.assertEqual( + utils.formataddr(('Arthur \Backslash\ Foobar', 'person at dom.ain')), + r'"Arthur \\Backslash\\ Foobar" ') + a = r'Arthur \Backslash\ Foobar' + b = 'person at dom.ain' + self.assertEqual(utils.parseaddr(utils.formataddr((a, b))), (a, b)) + + def test_name_with_dot(self): + x = 'John X. Doe ' + y = '"John X. Doe" ' + a, b = ('John X. Doe', 'jxd at example.com') + self.assertEqual(utils.parseaddr(x), (a, b)) + self.assertEqual(utils.parseaddr(y), (a, b)) + # formataddr() quotes the name if there's a dot in it + self.assertEqual(utils.formataddr((a, b)), y) + + def test_multiline_from_comment(self): + x = """\ +Foo +\tBar """ + self.assertEqual(utils.parseaddr(x), ('Foo Bar', 'foo at example.com')) + + def test_quote_dump(self): + self.assertEqual( + utils.formataddr(('A Silly; Person', 'person at dom.ain')), + r'"A Silly; Person" ') + + def test_fix_eols(self): + eq = self.assertEqual + eq(utils.fix_eols('hello'), 'hello') + eq(utils.fix_eols('hello\n'), 'hello\r\n') + eq(utils.fix_eols('hello\r'), 'hello\r\n') + eq(utils.fix_eols('hello\r\n'), 'hello\r\n') + eq(utils.fix_eols('hello\n\r'), 'hello\r\n\r\n') + + def test_charset_richcomparisons(self): + eq = self.assertEqual + ne = self.assertNotEqual + cset1 = Charset() + cset2 = Charset() + eq(cset1, 'us-ascii') + eq(cset1, 'US-ASCII') + eq(cset1, 'Us-AsCiI') + eq('us-ascii', cset1) + eq('US-ASCII', cset1) + eq('Us-AsCiI', cset1) + ne(cset1, 'usascii') + ne(cset1, 'USASCII') + ne(cset1, 'UsAsCiI') + ne('usascii', cset1) + ne('USASCII', cset1) + ne('UsAsCiI', cset1) + eq(cset1, cset2) + eq(cset2, cset1) + + def test_getaddresses(self): + eq = self.assertEqual + eq(utils.getaddresses(['aperson at dom.ain (Al Person)', + 'Bud Person ']), + [('Al Person', 'aperson at dom.ain'), + ('Bud Person', 'bperson at dom.ain')]) + + def test_getaddresses_nasty(self): + eq = self.assertEqual + 
eq(utils.getaddresses(['foo: ;']), [('', '')]) + eq(utils.getaddresses( + ['[]*-- =~$']), + [('', ''), ('', ''), ('', '*--')]) + eq(utils.getaddresses( + ['foo: ;', '"Jason R. Mastaler" ']), + [('', ''), ('Jason R. Mastaler', 'jason at dom.ain')]) + + def test_getaddresses_embedded_comment(self): + """Test proper handling of a nested comment""" + eq = self.assertEqual + addrs = utils.getaddresses(['User ((nested comment)) ']) + eq(addrs[0][1], 'foo at bar.com') + + def test_utils_quote_unquote(self): + eq = self.assertEqual + msg = Message() + msg.add_header('content-disposition', 'attachment', + filename='foo\\wacky"name') + eq(msg.get_filename(), 'foo\\wacky"name') + + def test_get_body_encoding_with_bogus_charset(self): + charset = Charset('not a charset') + self.assertEqual(charset.get_body_encoding(), 'base64') + + def test_get_body_encoding_with_uppercase_charset(self): + eq = self.assertEqual + msg = Message() + msg['Content-Type'] = 'text/plain; charset=UTF-8' + eq(msg['content-type'], 'text/plain; charset=UTF-8') + charsets = msg.get_charsets() + eq(len(charsets), 1) + eq(charsets[0], 'utf-8') + charset = Charset(charsets[0]) + eq(charset.get_body_encoding(), 'base64') + msg.set_payload('hello world', charset=charset) + eq(msg.get_payload(), 'aGVsbG8gd29ybGQ=\n') + eq(msg.get_payload(decode=True), 'hello world') + eq(msg['content-transfer-encoding'], 'base64') + # Try another one + msg = Message() + msg['Content-Type'] = 'text/plain; charset="US-ASCII"' + charsets = msg.get_charsets() + eq(len(charsets), 1) + eq(charsets[0], 'us-ascii') + charset = Charset(charsets[0]) + eq(charset.get_body_encoding(), encoders.encode_7or8bit) + msg.set_payload('hello world', charset=charset) + eq(msg.get_payload(), 'hello world') + eq(msg['content-transfer-encoding'], '7bit') + + def test_charsets_case_insensitive(self): + lc = Charset('us-ascii') + uc = Charset('US-ASCII') + self.assertEqual(lc.get_body_encoding(), uc.get_body_encoding()) + + def 
test_partial_falls_inside_message_delivery_status(self): + eq = self.ndiffAssertEqual + # The Parser interface provides chunks of data to FeedParser in 8192 + # byte gulps. SF bug #1076485 found one of those chunks inside + # message/delivery-status header block, which triggered an + # unreadline() of NeedMoreData. + msg = self._msgobj('msg_43.txt') + sfp = StringIO() + iterators._structure(msg, sfp) + eq(sfp.getvalue(), """\ +multipart/report + text/plain + message/delivery-status + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/plain + text/rfc822-headers +""") + + + +# Test the iterator/generators +class TestIterators(TestEmailBase): + def test_body_line_iterator(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + # First a simple non-multipart message + msg = self._msgobj('msg_01.txt') + it = iterators.body_line_iterator(msg) + lines = list(it) + eq(len(lines), 6) + neq(EMPTYSTRING.join(lines), msg.get_payload()) + # Now a more complicated multipart + msg = self._msgobj('msg_02.txt') + it = iterators.body_line_iterator(msg) + lines = list(it) + eq(len(lines), 43) + fp = openfile('msg_19.txt') + try: + neq(EMPTYSTRING.join(lines), fp.read()) + finally: + fp.close() + + def test_typed_subpart_iterator(self): + eq = self.assertEqual + msg = self._msgobj('msg_04.txt') + it = iterators.typed_subpart_iterator(msg, 'text') + lines = [] + subparts = 0 + for subpart in it: + subparts += 1 + lines.append(subpart.get_payload()) + eq(subparts, 2) + eq(EMPTYSTRING.join(lines), """\ +a simple kind of mirror +to reflect upon our own +a simple kind of mirror +to reflect upon our own +""") + + def test_typed_subpart_iterator_default_type(self): + eq = self.assertEqual + msg = 
self._msgobj('msg_03.txt') + it = iterators.typed_subpart_iterator(msg, 'text', 'plain') + lines = [] + subparts = 0 + for subpart in it: + subparts += 1 + lines.append(subpart.get_payload()) + eq(subparts, 1) + eq(EMPTYSTRING.join(lines), """\ + +Hi, + +Do you like this message? + +-Me +""") + + + +class TestParsers(TestEmailBase): + def test_header_parser(self): + eq = self.assertEqual + # Parse only the headers of a complex multipart MIME document + fp = openfile('msg_02.txt') + try: + msg = HeaderParser().parse(fp) + finally: + fp.close() + eq(msg['from'], 'ppp-request at zzz.org') + eq(msg['to'], 'ppp at zzz.org') + eq(msg.get_content_type(), 'multipart/mixed') + self.assertFalse(msg.is_multipart()) + self.assertTrue(isinstance(msg.get_payload(), str)) + + def test_whitespace_continuation(self): + eq = self.assertEqual + # This message contains a line after the Subject: header that has only + # whitespace, but it is not empty! + msg = email.message_from_string("""\ +From: aperson at dom.ain +To: bperson at dom.ain +Subject: the next line has a space on it +\x20 +Date: Mon, 8 Apr 2002 15:09:19 -0400 +Message-ID: spam + +Here's the message body +""") + eq(msg['subject'], 'the next line has a space on it\n ') + eq(msg['message-id'], 'spam') + eq(msg.get_payload(), "Here's the message body\n") + + def test_whitespace_continuation_last_header(self): + eq = self.assertEqual + # Like the previous test, but the subject line is the last + # header. 
+ msg = email.message_from_string("""\ +From: aperson at dom.ain +To: bperson at dom.ain +Date: Mon, 8 Apr 2002 15:09:19 -0400 +Message-ID: spam +Subject: the next line has a space on it +\x20 + +Here's the message body +""") + eq(msg['subject'], 'the next line has a space on it\n ') + eq(msg['message-id'], 'spam') + eq(msg.get_payload(), "Here's the message body\n") + + def test_crlf_separation(self): + eq = self.assertEqual + fp = openfile('msg_26.txt', mode='rb') + try: + msg = Parser().parse(fp) + finally: + fp.close() + eq(len(msg.get_payload()), 2) + part1 = msg.get_payload(0) + eq(part1.get_content_type(), 'text/plain') + eq(part1.get_payload(), 'Simple email with attachment.\r\n\r\n') + part2 = msg.get_payload(1) + eq(part2.get_content_type(), 'application/riscos') + + def test_multipart_digest_with_extra_mime_headers(self): + eq = self.assertEqual + neq = self.ndiffAssertEqual + fp = openfile('msg_28.txt') + try: + msg = email.message_from_file(fp) + finally: + fp.close() + # Structure is: + # multipart/digest + # message/rfc822 + # text/plain + # message/rfc822 + # text/plain + eq(msg.is_multipart(), 1) + eq(len(msg.get_payload()), 2) + part1 = msg.get_payload(0) + eq(part1.get_content_type(), 'message/rfc822') + eq(part1.is_multipart(), 1) + eq(len(part1.get_payload()), 1) + part1a = part1.get_payload(0) + eq(part1a.is_multipart(), 0) + eq(part1a.get_content_type(), 'text/plain') + neq(part1a.get_payload(), 'message 1\n') + # next message/rfc822 + part2 = msg.get_payload(1) + eq(part2.get_content_type(), 'message/rfc822') + eq(part2.is_multipart(), 1) + eq(len(part2.get_payload()), 1) + part2a = part2.get_payload(0) + eq(part2a.is_multipart(), 0) + eq(part2a.get_content_type(), 'text/plain') + neq(part2a.get_payload(), 'message 2\n') + + def test_three_lines(self): + # A bug report by Andrew McNamara + lines = ['From: Andrew Person From', 'From']) + eq(msg.get_payload(), 'body') + + def test_rfc2822_space_not_allowed_in_header(self): + eq = 
self.assertEqual + m = '>From foo at example.com 11:25:53\nFrom: bar\n!"#QUX;~: zoo\n\nbody' + msg = email.message_from_string(m) + eq(len(msg.keys()), 0) + + def test_rfc2822_one_character_header(self): + eq = self.assertEqual + m = 'A: first header\nB: second header\nCC: third header\n\nbody' + msg = email.message_from_string(m) + headers = msg.keys() + headers.sort() + eq(headers, ['A', 'B', 'CC']) + eq(msg.get_payload(), 'body') + + + +class TestBase64(unittest.TestCase): + def test_len(self): + eq = self.assertEqual + eq(base64mime.base64_len('hello'), + len(base64mime.encode('hello', eol=''))) + for size in range(15): + if size == 0 : bsize = 0 + elif size <= 3 : bsize = 4 + elif size <= 6 : bsize = 8 + elif size <= 9 : bsize = 12 + elif size <= 12: bsize = 16 + else : bsize = 20 + eq(base64mime.base64_len('x'*size), bsize) + + def test_decode(self): + eq = self.assertEqual + eq(base64mime.decode(''), '') + eq(base64mime.decode('aGVsbG8='), 'hello') + eq(base64mime.decode('aGVsbG8=', 'X'), 'hello') + eq(base64mime.decode('aGVsbG8NCndvcmxk\n', 'X'), 'helloXworld') + + def test_encode(self): + eq = self.assertEqual + eq(base64mime.encode(''), '') + eq(base64mime.encode('hello'), 'aGVsbG8=\n') + # Test the binary flag + eq(base64mime.encode('hello\n'), 'aGVsbG8K\n') + eq(base64mime.encode('hello\n', 0), 'aGVsbG8NCg==\n') + # Test the maxlinelen arg + eq(base64mime.encode('xxxx ' * 20, maxlinelen=40), """\ +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg +eHh4eCB4eHh4IA== +""") + # Test the eol argument + eq(base64mime.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r +eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r +eHh4eCB4eHh4IA==\r +""") + + def test_header_encode(self): + eq = self.assertEqual + he = base64mime.header_encode + eq(he('hello'), '=?iso-8859-1?b?aGVsbG8=?=') + eq(he('hello\nworld'), 
'=?iso-8859-1?b?aGVsbG8NCndvcmxk?=') + # Test the charset option + eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?b?aGVsbG8=?=') + # Test the keep_eols flag + eq(he('hello\nworld', keep_eols=True), + '=?iso-8859-1?b?aGVsbG8Kd29ybGQ=?=') + # Test the maxlinelen argument + eq(he('xxxx ' * 20, maxlinelen=40), """\ +=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?= + =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?= + =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?= + =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?= + =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?= + =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""") + # Test the eol argument + eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?=\r + =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?=\r + =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?=\r + =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?=\r + =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?=\r + =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""") + + + +class TestQuopri(unittest.TestCase): + def setUp(self): + self.hlit = [chr(x) for x in range(ord('a'), ord('z')+1)] + \ + [chr(x) for x in range(ord('A'), ord('Z')+1)] + \ + [chr(x) for x in range(ord('0'), ord('9')+1)] + \ + ['!', '*', '+', '-', '/', ' '] + self.hnon = [chr(x) for x in range(256) if chr(x) not in self.hlit] + assert len(self.hlit) + len(self.hnon) == 256 + self.blit = [chr(x) for x in range(ord(' '), ord('~')+1)] + ['\t'] + self.blit.remove('=') + self.bnon = [chr(x) for x in range(256) if chr(x) not in self.blit] + assert len(self.blit) + len(self.bnon) == 256 + + def test_header_quopri_check(self): + for c in self.hlit: + self.assertFalse(quoprimime.header_quopri_check(c)) + for c in self.hnon: + self.assertTrue(quoprimime.header_quopri_check(c)) + + def test_body_quopri_check(self): + for c in self.blit: + self.assertFalse(quoprimime.body_quopri_check(c)) + for c in self.bnon: + self.assertTrue(quoprimime.body_quopri_check(c)) + + def test_header_quopri_len(self): + eq = self.assertEqual + hql = 
quoprimime.header_quopri_len + enc = quoprimime.header_encode + for s in ('hello', 'h at e@l at l@o@'): + # Empty charset and no line-endings. 7 == RFC chrome + eq(hql(s), len(enc(s, charset='', eol=''))-7) + for c in self.hlit: + eq(hql(c), 1) + for c in self.hnon: + eq(hql(c), 3) + + def test_body_quopri_len(self): + eq = self.assertEqual + bql = quoprimime.body_quopri_len + for c in self.blit: + eq(bql(c), 1) + for c in self.bnon: + eq(bql(c), 3) + + def test_quote_unquote_idempotent(self): + for x in range(256): + c = chr(x) + self.assertEqual(quoprimime.unquote(quoprimime.quote(c)), c) + + def test_header_encode(self): + eq = self.assertEqual + he = quoprimime.header_encode + eq(he('hello'), '=?iso-8859-1?q?hello?=') + eq(he('hello\nworld'), '=?iso-8859-1?q?hello=0D=0Aworld?=') + # Test the charset option + eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?q?hello?=') + # Test the keep_eols flag + eq(he('hello\nworld', keep_eols=True), '=?iso-8859-1?q?hello=0Aworld?=') + # Test a non-ASCII character + eq(he('hello\xc7there'), '=?iso-8859-1?q?hello=C7there?=') + # Test the maxlinelen argument + eq(he('xxxx ' * 20, maxlinelen=40), """\ +=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?= + =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?= + =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?= + =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?= + =?iso-8859-1?q?x_xxxx_xxxx_?=""") + # Test the eol argument + eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?=\r + =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?=\r + =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?=\r + =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?=\r + =?iso-8859-1?q?x_xxxx_xxxx_?=""") + + def test_decode(self): + eq = self.assertEqual + eq(quoprimime.decode(''), '') + eq(quoprimime.decode('hello'), 'hello') + eq(quoprimime.decode('hello', 'X'), 'hello') + eq(quoprimime.decode('hello\nworld', 'X'), 'helloXworld') + + def test_encode(self): + eq = self.assertEqual + eq(quoprimime.encode(''), '') + 
eq(quoprimime.encode('hello'), 'hello') + # Test the binary flag + eq(quoprimime.encode('hello\r\nworld'), 'hello\nworld') + eq(quoprimime.encode('hello\r\nworld', 0), 'hello\nworld') + # Test the maxlinelen arg + eq(quoprimime.encode('xxxx ' * 20, maxlinelen=40), """\ +xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx= + xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx= +x xxxx xxxx xxxx xxxx=20""") + # Test the eol argument + eq(quoprimime.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\ +xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=\r + xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=\r +x xxxx xxxx xxxx xxxx=20""") + eq(quoprimime.encode("""\ +one line + +two line"""), """\ +one line + +two line""") + + + +# Test the Charset class +class TestCharset(unittest.TestCase): + def tearDown(self): + from email import charset as CharsetModule + try: + del CharsetModule.CHARSETS['fake'] + except KeyError: + pass + + def test_idempotent(self): + eq = self.assertEqual + # Make sure us-ascii = no Unicode conversion + c = Charset('us-ascii') + s = 'Hello World!' 
+ sp = c.to_splittable(s) + eq(s, c.from_splittable(sp)) + # test 8-bit idempotency with us-ascii + s = '\xa4\xa2\xa4\xa4\xa4\xa6\xa4\xa8\xa4\xaa' + sp = c.to_splittable(s) + eq(s, c.from_splittable(sp)) + + def test_body_encode(self): + eq = self.assertEqual + # Try a charset with QP body encoding + c = Charset('iso-8859-1') + eq('hello w=F6rld', c.body_encode('hello w\xf6rld')) + # Try a charset with Base64 body encoding + c = Charset('utf-8') + eq('aGVsbG8gd29ybGQ=\n', c.body_encode('hello world')) + # Try a charset with None body encoding + c = Charset('us-ascii') + eq('hello world', c.body_encode('hello world')) + # Try the convert argument, where input codec != output codec + c = Charset('euc-jp') + # With apologies to Tokio Kikuchi ;) + if not is_jython: + # TODO Jython with its Java-based codecs does not + # currently support trailing bytes in CJK texts + try: + eq('\x1b$B5FCO;~IW\x1b(B', + c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7')) + eq('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7', + c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7', False)) + except LookupError: + # We probably don't have the Japanese codecs installed + pass + # Testing SF bug #625509, which we have to fake, since there are no + # built-in encodings where the header encoding is QP but the body + # encoding is not. + from email import charset as CharsetModule + CharsetModule.add_charset('fake', CharsetModule.QP, None) + c = Charset('fake') + eq('hello w\xf6rld', c.body_encode('hello w\xf6rld')) + + def test_unicode_charset_name(self): + charset = Charset(u'us-ascii') + self.assertEqual(str(charset), 'us-ascii') + self.assertRaises(errors.CharsetError, Charset, 'asc\xffii') + + + +# Test multilingual MIME headers. +class TestHeader(TestEmailBase): + def test_simple(self): + eq = self.ndiffAssertEqual + h = Header('Hello World!') + eq(h.encode(), 'Hello World!') + h.append(' Goodbye World!') + eq(h.encode(), 'Hello World! 
Goodbye World!') + + def test_simple_surprise(self): + eq = self.ndiffAssertEqual + h = Header('Hello World!') + eq(h.encode(), 'Hello World!') + h.append('Goodbye World!') + eq(h.encode(), 'Hello World! Goodbye World!') + + def test_header_needs_no_decoding(self): + h = 'no decoding needed' + self.assertEqual(decode_header(h), [(h, None)]) + + def test_long(self): + h = Header("I am the very model of a modern Major-General; I've information vegetable, animal, and mineral; I know the kings of England, and I quote the fights historical from Marathon to Waterloo, in order categorical; I'm very well acquainted, too, with matters mathematical; I understand equations, both the simple and quadratical; about binomial theorem I'm teeming with a lot o' news, with many cheerful facts about the square of the hypotenuse.", + maxlinelen=76) + for l in h.encode(splitchars=' ').split('\n '): + self.assertTrue(len(l) <= 76) + + def test_multilingual(self): + eq = self.ndiffAssertEqual + g = Charset("iso-8859-1") + cz = Charset("iso-8859-2") + utf8 = Charset("utf-8") + g_head = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. " + cz_head = "Finan\xe8ni metropole se hroutily pod tlakem jejich d\xf9vtipu.. " + utf8_head = u"\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das Nunstuck git und Slotermeyer? Ja! 
Beiherhund das Oder die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066\u3044\u307e\u3059\u3002".encode("utf-8") + h = Header(g_head, g) + h.append(cz_head, cz) + h.append(utf8_head, utf8) + enc = h.encode() + eq(enc, """\ +=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerderband_ko?= + =?iso-8859-1?q?mfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndischen_Wan?= + =?iso-8859-1?q?dgem=E4lden_vorbei=2C_gegen_die_rotierenden_Klingen_bef=F6?= + =?iso-8859-1?q?rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_metropole_se_hroutily?= + =?iso-8859-2?q?_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= =?utf-8?b?5q2j56K6?= + =?utf-8?b?44Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE44G+44Gb44KT44CC?= + =?utf-8?b?5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB44GC44Go44Gv44Gn?= + =?utf-8?b?44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CMV2VubiBpc3QgZGFz?= + =?utf-8?q?_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das_Oder_die_Fl?= + =?utf-8?b?aXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBo+OBpuOBhOOBvuOBmQ==?= + =?utf-8?b?44CC?=""") + eq(decode_header(enc), + [(g_head, "iso-8859-1"), (cz_head, "iso-8859-2"), + (utf8_head, "utf-8")]) + ustr = unicode(h) + eq(ustr.encode('utf-8'), + 'Die Mieter treten hier ein werden mit einem Foerderband ' + 'komfortabel den Korridor entlang, an s\xc3\xbcdl\xc3\xbcndischen ' + 'Wandgem\xc3\xa4lden vorbei, gegen die rotierenden Klingen ' + 'bef\xc3\xb6rdert. Finan\xc4\x8dni metropole se hroutily pod ' + 'tlakem jejich d\xc5\xafvtipu.. 
\xe6\xad\xa3\xe7\xa2\xba\xe3\x81' + '\xab\xe8\xa8\x80\xe3\x81\x86\xe3\x81\xa8\xe7\xbf\xbb\xe8\xa8\xb3' + '\xe3\x81\xaf\xe3\x81\x95\xe3\x82\x8c\xe3\x81\xa6\xe3\x81\x84\xe3' + '\x81\xbe\xe3\x81\x9b\xe3\x82\x93\xe3\x80\x82\xe4\xb8\x80\xe9\x83' + '\xa8\xe3\x81\xaf\xe3\x83\x89\xe3\x82\xa4\xe3\x83\x84\xe8\xaa\x9e' + '\xe3\x81\xa7\xe3\x81\x99\xe3\x81\x8c\xe3\x80\x81\xe3\x81\x82\xe3' + '\x81\xa8\xe3\x81\xaf\xe3\x81\xa7\xe3\x81\x9f\xe3\x82\x89\xe3\x82' + '\x81\xe3\x81\xa7\xe3\x81\x99\xe3\x80\x82\xe5\xae\x9f\xe9\x9a\x9b' + '\xe3\x81\xab\xe3\x81\xaf\xe3\x80\x8cWenn ist das Nunstuck git ' + 'und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt ' + 'gersput.\xe3\x80\x8d\xe3\x81\xa8\xe8\xa8\x80\xe3\x81\xa3\xe3\x81' + '\xa6\xe3\x81\x84\xe3\x81\xbe\xe3\x81\x99\xe3\x80\x82') + # Test make_header() + newh = make_header(decode_header(enc)) + eq(newh, enc) + + def test_header_ctor_default_args(self): + eq = self.ndiffAssertEqual + h = Header() + eq(h, '') + h.append('foo', Charset('iso-8859-1')) + eq(h, '=?iso-8859-1?q?foo?=') + + def test_explicit_maxlinelen(self): + eq = self.ndiffAssertEqual + hstr = 'A very long line that must get split to something other than at the 76th character boundary to test the non-default behavior' + h = Header(hstr) + eq(h.encode(), '''\ +A very long line that must get split to something other than at the 76th + character boundary to test the non-default behavior''') + h = Header(hstr, header_name='Subject') + eq(h.encode(), '''\ +A very long line that must get split to something other than at the + 76th character boundary to test the non-default behavior''') + h = Header(hstr, maxlinelen=1024, header_name='Subject') + eq(h.encode(), hstr) + + def test_us_ascii_header(self): + eq = self.assertEqual + s = 'hello' + x = decode_header(s) + eq(x, [('hello', None)]) + h = make_header(x) + eq(s, h.encode()) + + def test_string_charset(self): + eq = self.assertEqual + h = Header() + h.append('hello', 'iso-8859-1') + eq(h, '=?iso-8859-1?q?hello?=') + +## 
def test_unicode_error(self): +## raises = self.assertRaises +## raises(UnicodeError, Header, u'[P\xf6stal]', 'us-ascii') +## raises(UnicodeError, Header, '[P\xf6stal]', 'us-ascii') +## h = Header() +## raises(UnicodeError, h.append, u'[P\xf6stal]', 'us-ascii') +## raises(UnicodeError, h.append, '[P\xf6stal]', 'us-ascii') +## raises(UnicodeError, Header, u'\u83ca\u5730\u6642\u592b', 'iso-8859-1') + + def test_utf8_shortest(self): + eq = self.assertEqual + h = Header(u'p\xf6stal', 'utf-8') + eq(h.encode(), '=?utf-8?q?p=C3=B6stal?=') + h = Header(u'\u83ca\u5730\u6642\u592b', 'utf-8') + eq(h.encode(), '=?utf-8?b?6I+K5Zyw5pmC5aSr?=') + + def test_bad_8bit_header(self): + raises = self.assertRaises + eq = self.assertEqual + x = 'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big' + raises(UnicodeError, Header, x) + h = Header() + raises(UnicodeError, h.append, x) + eq(str(Header(x, errors='replace')), x) + h.append(x, errors='replace') + eq(str(h), x) + + def test_encoded_adjacent_nonencoded(self): + eq = self.assertEqual + h = Header() + h.append('hello', 'iso-8859-1') + h.append('world') + s = h.encode() + eq(s, '=?iso-8859-1?q?hello?= world') + h = make_header(decode_header(s)) + eq(h.encode(), s) + + def test_whitespace_eater(self): + eq = self.assertEqual + s = 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztk=?= =?koi8-r?q?=CA?= zz.' 
+ parts = decode_header(s) + eq(parts, [('Subject:', None), ('\xf0\xd2\xcf\xd7\xc5\xd2\xcb\xc1 \xce\xc1 \xc6\xc9\xce\xc1\xcc\xd8\xce\xd9\xca', 'koi8-r'), ('zz.', None)]) + hdr = make_header(parts) + eq(hdr.encode(), + 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztnK?= zz.') + + def test_broken_base64_header(self): + raises = self.assertRaises + s = 'Subject: =?EUC-KR?B?CSixpLDtKSC/7Liuvsax4iC6uLmwMcijIKHaILzSwd/H0SC8+LCjwLsgv7W/+Mj3I ?=' + raises(errors.HeaderParseError, decode_header, s) + + + +# Test RFC 2231 header parameters (en/de)coding +class TestRFC2231(TestEmailBase): + def test_get_param(self): + eq = self.assertEqual + msg = self._msgobj('msg_29.txt') + eq(msg.get_param('title'), + ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!')) + eq(msg.get_param('title', unquote=False), + ('us-ascii', 'en', '"This is even more ***fun*** isn\'t it!"')) + + def test_set_param(self): + eq = self.assertEqual + msg = Message() + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii') + eq(msg.get_param('title'), + ('us-ascii', '', 'This is even more ***fun*** isn\'t it!')) + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii', language='en') + eq(msg.get_param('title'), + ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!')) + msg = self._msgobj('msg_01.txt') + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii', language='en') + self.ndiffAssertEqual(msg.as_string(), """\ +Return-Path: +Delivered-To: bbb at zzz.org +Received: by mail.zzz.org (Postfix, from userid 889) + id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT) +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit +Message-ID: <15090.61304.110929.45684 at aaa.zzz.org> +From: bbb at ddd.com (John X. 
Doe) +To: bbb at zzz.org +Subject: This is a test message +Date: Fri, 4 May 2001 14:05:44 -0400 +Content-Type: text/plain; charset=us-ascii; + title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21" + + +Hi, + +Do you like this message? + +-Me +""") + + def test_del_param(self): + eq = self.ndiffAssertEqual + msg = self._msgobj('msg_01.txt') + msg.set_param('foo', 'bar', charset='us-ascii', language='en') + msg.set_param('title', 'This is even more ***fun*** isn\'t it!', + charset='us-ascii', language='en') + msg.del_param('foo', header='Content-Type') + eq(msg.as_string(), """\ +Return-Path: +Delivered-To: bbb at zzz.org +Received: by mail.zzz.org (Postfix, from userid 889) + id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT) +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit +Message-ID: <15090.61304.110929.45684 at aaa.zzz.org> +From: bbb at ddd.com (John X. Doe) +To: bbb at zzz.org +Subject: This is a test message +Date: Fri, 4 May 2001 14:05:44 -0400 +Content-Type: text/plain; charset="us-ascii"; + title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21" + + +Hi, + +Do you like this message? 
+ +-Me +""") + + def test_rfc2231_get_content_charset(self): + eq = self.assertEqual + msg = self._msgobj('msg_32.txt') + eq(msg.get_content_charset(), 'us-ascii') + + def test_rfc2231_no_language_or_charset(self): + m = '''\ +Content-Transfer-Encoding: 8bit +Content-Disposition: inline; filename="file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm" +Content-Type: text/html; NAME*0=file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEM; NAME*1=P_nsmail.htm + +''' + msg = email.message_from_string(m) + param = msg.get_param('NAME') + self.assertFalse(isinstance(param, tuple)) + self.assertEqual( + param, + 'file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm') + + def test_rfc2231_no_language_or_charset_in_filename(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="''This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + 'This is even more ***fun*** is it not.pdf') + + def test_rfc2231_no_language_or_charset_in_filename_encoded(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="''This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + 'This is even more ***fun*** is it not.pdf') + + def test_rfc2231_partly_encoded(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0="''This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual( + msg.get_filename(), + 'This%20is%20even%20more%20***fun*** is it not.pdf') + + def test_rfc2231_partly_nonencoded(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0="This%20is%20even%20more%20"; +\tfilename*1="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + 
+''' + msg = email.message_from_string(m) + self.assertEqual( + msg.get_filename(), + 'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20is it not.pdf') + + def test_rfc2231_no_language_or_charset_in_boundary(self): + m = '''\ +Content-Type: multipart/alternative; +\tboundary*0*="''This%20is%20even%20more%20"; +\tboundary*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tboundary*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_boundary(), + 'This is even more ***fun*** is it not.pdf') + + def test_rfc2231_no_language_or_charset_in_charset(self): + # This is a nonsensical charset value, but tests the code anyway + m = '''\ +Content-Type: text/plain; +\tcharset*0*="This%20is%20even%20more%20"; +\tcharset*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tcharset*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_content_charset(), + 'this is even more ***fun*** is it not.pdf') + + def test_rfc2231_bad_encoding_in_filename(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="bogus'xx'This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2="is it not.pdf" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + 'This is even more ***fun*** is it not.pdf') + + def test_rfc2231_bad_encoding_in_charset(self): + m = """\ +Content-Type: text/plain; charset*=bogus''utf-8%E2%80%9D + +""" + msg = email.message_from_string(m) + # This should return None because non-ascii characters in the charset + # are not allowed. + self.assertEqual(msg.get_content_charset(), None) + + def test_rfc2231_bad_character_in_charset(self): + m = """\ +Content-Type: text/plain; charset*=ascii''utf-8%E2%80%9D + +""" + msg = email.message_from_string(m) + # This should return None because non-ascii characters in the charset + # are not allowed. 
+ self.assertEqual(msg.get_content_charset(), None) + + def test_rfc2231_bad_character_in_filename(self): + m = '''\ +Content-Disposition: inline; +\tfilename*0*="ascii'xx'This%20is%20even%20more%20"; +\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20"; +\tfilename*2*="is it not.pdf%E2" + +''' + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), + u'This is even more ***fun*** is it not.pdf\ufffd') + + def test_rfc2231_unknown_encoding(self): + m = """\ +Content-Transfer-Encoding: 8bit +Content-Disposition: inline; filename*=X-UNKNOWN''myfile.txt + +""" + msg = email.message_from_string(m) + self.assertEqual(msg.get_filename(), 'myfile.txt') + + def test_rfc2231_single_tick_in_filename_extended(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0*=\"Frank's\"; name*1*=\" Document\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, None) + eq(language, None) + eq(s, "Frank's Document") + + def test_rfc2231_single_tick_in_filename(self): + m = """\ +Content-Type: application/x-foo; name*0=\"Frank's\"; name*1=\" Document\" + +""" + msg = email.message_from_string(m) + param = msg.get_param('name') + self.assertFalse(isinstance(param, tuple)) + self.assertEqual(param, "Frank's Document") + + def test_rfc2231_tick_attack_extended(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0*=\"us-ascii'en-us'Frank's\"; name*1*=\" Document\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, 'us-ascii') + eq(language, 'en-us') + eq(s, "Frank's Document") + + def test_rfc2231_tick_attack(self): + m = """\ +Content-Type: application/x-foo; +\tname*0=\"us-ascii'en-us'Frank's\"; name*1=\" Document\" + +""" + msg = email.message_from_string(m) + param = msg.get_param('name') + self.assertFalse(isinstance(param, tuple)) + self.assertEqual(param, "us-ascii'en-us'Frank's Document") + + def 
test_rfc2231_no_extended_values(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; name=\"Frank's Document\" + +""" + msg = email.message_from_string(m) + eq(msg.get_param('name'), "Frank's Document") + + def test_rfc2231_encoded_then_unencoded_segments(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0*=\"us-ascii'en-us'My\"; +\tname*1=\" Document\"; +\tname*2*=\" For You\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, 'us-ascii') + eq(language, 'en-us') + eq(s, 'My Document For You') + + def test_rfc2231_unencoded_then_encoded_segments(self): + eq = self.assertEqual + m = """\ +Content-Type: application/x-foo; +\tname*0=\"us-ascii'en-us'My\"; +\tname*1*=\" Document\"; +\tname*2*=\" For You\" + +""" + msg = email.message_from_string(m) + charset, language, s = msg.get_param('name') + eq(charset, 'us-ascii') + eq(language, 'en-us') + eq(s, 'My Document For You') + + + +def _testclasses(): + mod = sys.modules[__name__] + return [getattr(mod, name) for name in dir(mod) if name.startswith('Test')] + + +def suite(): + suite = unittest.TestSuite() + for testclass in _testclasses(): + suite.addTest(unittest.makeSuite(testclass)) + return suite + + +def test_main(): + for testclass in _testclasses(): + run_unittest(testclass) + + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 07:53:00 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 06:53:00 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_object=28=29_should_support?= =?utf-8?q?_rich_comparison_against_other_objects?= Message-ID: <20141209065258.125275.19155@psf.io> https://hg.python.org/jython/rev/e9c37cf662cf changeset: 7442:e9c37cf662cf user: Jim Baker date: Mon Dec 08 23:45:31 2014 -0700 summary: object() should support rich comparison against other objects 
Removes a mysterious FIXME skip in test_cgi due to the use of object() as a sentinel in a test against a list of values. files: Lib/test/test_cgi.py | 395 -------------- Lib/test/test_cmp_jy.py | 43 +- src/org/python/core/PyDictionary.java | 4 +- src/org/python/core/PySequence.java | 12 +- 4 files changed, 50 insertions(+), 404 deletions(-) diff --git a/Lib/test/test_cgi.py b/Lib/test/test_cgi.py deleted file mode 100644 --- a/Lib/test/test_cgi.py +++ /dev/null @@ -1,395 +0,0 @@ -from test.test_support import run_unittest, check_warnings -import cgi -import os -import sys -import tempfile -import unittest - -class HackedSysModule: - # The regression test will have real values in sys.argv, which - # will completely confuse the test of the cgi module - argv = [] - stdin = sys.stdin - -cgi.sys = HackedSysModule() - -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO - -class ComparableException: - def __init__(self, err): - self.err = err - - def __str__(self): - return str(self.err) - - def __cmp__(self, anExc): - if not isinstance(anExc, Exception): - return -1 - x = cmp(self.err.__class__, anExc.__class__) - if x != 0: - return x - return cmp(self.err.args, anExc.args) - - def __getattr__(self, attr): - return getattr(self.err, attr) - -def do_test(buf, method): - env = {} - if method == "GET": - fp = None - env['REQUEST_METHOD'] = 'GET' - env['QUERY_STRING'] = buf - elif method == "POST": - fp = StringIO(buf) - env['REQUEST_METHOD'] = 'POST' - env['CONTENT_TYPE'] = 'application/x-www-form-urlencoded' - env['CONTENT_LENGTH'] = str(len(buf)) - else: - raise ValueError, "unknown method: %s" % method - try: - return cgi.parse(fp, env, strict_parsing=1) - except StandardError, err: - return ComparableException(err) - -parse_strict_test_cases = [ - ("", ValueError("bad query field: ''")), - ("&", ValueError("bad query field: ''")), - ("&&", ValueError("bad query field: ''")), - (";", ValueError("bad query field: ''")), - (";&;", 
ValueError("bad query field: ''")), - # Should the next few really be valid? - ("=", {}), - ("=&=", {}), - ("=;=", {}), - # This rest seem to make sense - ("=a", {'': ['a']}), - ("&=a", ValueError("bad query field: ''")), - ("=a&", ValueError("bad query field: ''")), - ("=&a", ValueError("bad query field: 'a'")), - ("b=a", {'b': ['a']}), - ("b+=a", {'b ': ['a']}), - ("a=b=a", {'a': ['b=a']}), - ("a=+b=a", {'a': [' b=a']}), - ("&b=a", ValueError("bad query field: ''")), - ("b&=a", ValueError("bad query field: 'b'")), -#FIXME: None of these are working in Jython -# ("a=a+b&b=b+c", {'a': ['a b'], 'b': ['b c']}), -# ("a=a+b&a=b+a", {'a': ['a b', 'b a']}), -# ("x=1&y=2.0&z=2-3.%2b0", {'x': ['1'], 'y': ['2.0'], 'z': ['2-3.+0']}), -# ("x=1;y=2.0&z=2-3.%2b0", {'x': ['1'], 'y': ['2.0'], 'z': ['2-3.+0']}), -# ("x=1;y=2.0;z=2-3.%2b0", {'x': ['1'], 'y': ['2.0'], 'z': ['2-3.+0']}), -# ("Hbc5161168c542333633315dee1182227:key_store_seqid=400006&cuyer=r&view=bustomer&order_id=0bb2e248638833d48cb7fed300000f1b&expire=964546263&lobale=en-US&kid=130003.300038&ss=env", -# {'Hbc5161168c542333633315dee1182227:key_store_seqid': ['400006'], -# 'cuyer': ['r'], -# 'expire': ['964546263'], -# 'kid': ['130003.300038'], -# 'lobale': ['en-US'], -# 'order_id': ['0bb2e248638833d48cb7fed300000f1b'], -# 'ss': ['env'], -# 'view': ['bustomer'], -# }), -# -# ("group_id=5470&set=custom&_assigned_to=31392&_status=1&_category=100&SUBMIT=Browse", -# {'SUBMIT': ['Browse'], -# '_assigned_to': ['31392'], -# '_category': ['100'], -# '_status': ['1'], -# 'group_id': ['5470'], -# 'set': ['custom'], -# }) - ] - -def first_elts(list): - return map(lambda x:x[0], list) - -def first_second_elts(list): - return map(lambda p:(p[0], p[1][0]), list) - -def gen_result(data, environ): - fake_stdin = StringIO(data) - fake_stdin.seek(0) - form = cgi.FieldStorage(fp=fake_stdin, environ=environ) - - result = {} - for k, v in dict(form).items(): - result[k] = isinstance(v, list) and form.getlist(k) or v.value - - return result 
- -class CgiTests(unittest.TestCase): - - def test_escape(self): - self.assertEqual("test & string", cgi.escape("test & string")) - self.assertEqual("<test string>", cgi.escape("")) - self.assertEqual(""test string"", cgi.escape('"test string"', True)) - - def test_strict(self): - for orig, expect in parse_strict_test_cases: - # Test basic parsing - d = do_test(orig, "GET") - self.assertEqual(d, expect, "Error parsing %s" % repr(orig)) - d = do_test(orig, "POST") - self.assertEqual(d, expect, "Error parsing %s" % repr(orig)) - - env = {'QUERY_STRING': orig} - fcd = cgi.FormContentDict(env) - sd = cgi.SvFormContentDict(env) - fs = cgi.FieldStorage(environ=env) - if isinstance(expect, dict): - # test dict interface - self.assertEqual(len(expect), len(fcd)) - self.assertItemsEqual(expect.keys(), fcd.keys()) - self.assertItemsEqual(expect.values(), fcd.values()) - self.assertItemsEqual(expect.items(), fcd.items()) - self.assertEqual(fcd.get("nonexistent field", "default"), "default") - self.assertEqual(len(sd), len(fs)) - self.assertItemsEqual(sd.keys(), fs.keys()) - self.assertEqual(fs.getvalue("nonexistent field", "default"), "default") - # test individual fields - for key in expect.keys(): - expect_val = expect[key] - self.assertTrue(fcd.has_key(key)) - self.assertItemsEqual(fcd[key], expect[key]) - self.assertEqual(fcd.get(key, "default"), fcd[key]) - self.assertTrue(fs.has_key(key)) - if len(expect_val) > 1: - single_value = 0 - else: - single_value = 1 - try: - val = sd[key] - except IndexError: - self.assertFalse(single_value) - self.assertEqual(fs.getvalue(key), expect_val) - else: - self.assertTrue(single_value) - self.assertEqual(val, expect_val[0]) - self.assertEqual(fs.getvalue(key), expect_val[0]) - self.assertItemsEqual(sd.getlist(key), expect_val) - if single_value: - self.assertItemsEqual(sd.values(), - first_elts(expect.values())) - self.assertItemsEqual(sd.items(), - first_second_elts(expect.items())) - - def test_weird_formcontentdict(self): - # Test 
the weird FormContentDict classes - env = {'QUERY_STRING': "x=1&y=2.0&z=2-3.%2b0&1=1abc"} - expect = {'x': 1, 'y': 2.0, 'z': '2-3.+0', '1': '1abc'} - d = cgi.InterpFormContentDict(env) - for k, v in expect.items(): - self.assertEqual(d[k], v) - for k, v in d.items(): - self.assertEqual(expect[k], v) - self.assertItemsEqual(expect.values(), d.values()) - - def test_log(self): - cgi.log("Testing") - - cgi.logfp = StringIO() - cgi.initlog("%s", "Testing initlog 1") - cgi.log("%s", "Testing log 2") - self.assertEqual(cgi.logfp.getvalue(), "Testing initlog 1\nTesting log 2\n") - if os.path.exists("/dev/null"): - cgi.logfp = None - cgi.logfile = "/dev/null" - cgi.initlog("%s", "Testing log 3") - cgi.log("Testing log 4") - - def test_fieldstorage_readline(self): - # FieldStorage uses readline, which has the capacity to read all - # contents of the input file into memory; we use readline's size argument - # to prevent that for files that do not contain any newlines in - # non-GET/HEAD requests - class TestReadlineFile: - def __init__(self, file): - self.file = file - self.numcalls = 0 - - def readline(self, size=None): - self.numcalls += 1 - if size: - return self.file.readline(size) - else: - return self.file.readline() - - def __getattr__(self, name): - file = self.__dict__['file'] - a = getattr(file, name) - if not isinstance(a, int): - setattr(self, name, a) - return a - - f = TestReadlineFile(tempfile.TemporaryFile()) - f.write('x' * 256 * 1024) - f.seek(0) - env = {'REQUEST_METHOD':'PUT'} - fs = cgi.FieldStorage(fp=f, environ=env) - # if we're not chunking properly, readline is only called twice - # (by read_binary); if we are chunking properly, it will be called 5 times - # as long as the chunksize is 1 << 16. 
- self.assertTrue(f.numcalls > 2) - - def test_fieldstorage_multipart(self): - #Test basic FieldStorage multipart parsing - env = {'REQUEST_METHOD':'POST', 'CONTENT_TYPE':'multipart/form-data; boundary=---------------------------721837373350705526688164684', 'CONTENT_LENGTH':'558'} - postdata = """-----------------------------721837373350705526688164684 -Content-Disposition: form-data; name="id" - -1234 ------------------------------721837373350705526688164684 -Content-Disposition: form-data; name="title" - - ------------------------------721837373350705526688164684 -Content-Disposition: form-data; name="file"; filename="test.txt" -Content-Type: text/plain - -Testing 123. - ------------------------------721837373350705526688164684 -Content-Disposition: form-data; name="submit" - - Add\x20 ------------------------------721837373350705526688164684-- -""" - fs = cgi.FieldStorage(fp=StringIO(postdata), environ=env) - self.assertEqual(len(fs.list), 4) - expect = [{'name':'id', 'filename':None, 'value':'1234'}, - {'name':'title', 'filename':None, 'value':''}, - {'name':'file', 'filename':'test.txt','value':'Testing 123.\n'}, - {'name':'submit', 'filename':None, 'value':' Add '}] - for x in range(len(fs.list)): - for k, exp in expect[x].items(): - got = getattr(fs.list[x], k) - self.assertEqual(got, exp) - - _qs_result = { - 'key1': 'value1', - 'key2': ['value2x', 'value2y'], - 'key3': 'value3', - 'key4': 'value4' - } - def testQSAndUrlEncode(self): - data = "key2=value2x&key3=value3&key4=value4" - environ = { - 'CONTENT_LENGTH': str(len(data)), - 'CONTENT_TYPE': 'application/x-www-form-urlencoded', - 'QUERY_STRING': 'key1=value1&key2=value2y', - 'REQUEST_METHOD': 'POST', - } - v = gen_result(data, environ) - self.assertEqual(self._qs_result, v) - - def testQSAndFormData(self): - data = """ ----123 -Content-Disposition: form-data; name="key2" - -value2y ----123 -Content-Disposition: form-data; name="key3" - -value3 ----123 -Content-Disposition: form-data; name="key4" - 
-value4 ----123-- -""" - environ = { - 'CONTENT_LENGTH': str(len(data)), - 'CONTENT_TYPE': 'multipart/form-data; boundary=-123', - 'QUERY_STRING': 'key1=value1&key2=value2x', - 'REQUEST_METHOD': 'POST', - } - v = gen_result(data, environ) - self.assertEqual(self._qs_result, v) - - def testQSAndFormDataFile(self): - data = """ ----123 -Content-Disposition: form-data; name="key2" - -value2y ----123 -Content-Disposition: form-data; name="key3" - -value3 ----123 -Content-Disposition: form-data; name="key4" - -value4 ----123 -Content-Disposition: form-data; name="upload"; filename="fake.txt" -Content-Type: text/plain - -this is the content of the fake file - ----123-- -""" - environ = { - 'CONTENT_LENGTH': str(len(data)), - 'CONTENT_TYPE': 'multipart/form-data; boundary=-123', - 'QUERY_STRING': 'key1=value1&key2=value2x', - 'REQUEST_METHOD': 'POST', - } - result = self._qs_result.copy() - result.update({ - 'upload': 'this is the content of the fake file\n' - }) - v = gen_result(data, environ) - self.assertEqual(result, v) - - def test_deprecated_parse_qs(self): - # this func is moved to urlparse, this is just a sanity check - with check_warnings(('cgi.parse_qs is deprecated, use urlparse.' - 'parse_qs instead', PendingDeprecationWarning)): - self.assertEqual({'a': ['A1'], 'B': ['B3'], 'b': ['B2']}, - cgi.parse_qs('a=A1&b=B2&B=B3')) - - def test_deprecated_parse_qsl(self): - # this func is moved to urlparse, this is just a sanity check - with check_warnings(('cgi.parse_qsl is deprecated, use urlparse.' 
- 'parse_qsl instead', PendingDeprecationWarning)): - self.assertEqual([('a', 'A1'), ('b', 'B2'), ('B', 'B3')], - cgi.parse_qsl('a=A1&b=B2&B=B3')) - - def test_parse_header(self): - self.assertEqual( - cgi.parse_header("text/plain"), - ("text/plain", {})) - self.assertEqual( - cgi.parse_header("text/vnd.just.made.this.up ; "), - ("text/vnd.just.made.this.up", {})) - self.assertEqual( - cgi.parse_header("text/plain;charset=us-ascii"), - ("text/plain", {"charset": "us-ascii"})) - self.assertEqual( - cgi.parse_header('text/plain ; charset="us-ascii"'), - ("text/plain", {"charset": "us-ascii"})) - self.assertEqual( - cgi.parse_header('text/plain ; charset="us-ascii"; another=opt'), - ("text/plain", {"charset": "us-ascii", "another": "opt"})) - self.assertEqual( - cgi.parse_header('attachment; filename="silly.txt"'), - ("attachment", {"filename": "silly.txt"})) - self.assertEqual( - cgi.parse_header('attachment; filename="strange;name"'), - ("attachment", {"filename": "strange;name"})) - self.assertEqual( - cgi.parse_header('attachment; filename="strange;name";size=123;'), - ("attachment", {"filename": "strange;name", "size": "123"})) - self.assertEqual( - cgi.parse_header('form-data; name="files"; filename="fo\\"o;bar"'), - ("form-data", {"name": "files", "filename": 'fo"o;bar'})) - - -def test_main(): - run_unittest(CgiTests) - -if __name__ == '__main__': - test_main() diff --git a/Lib/test/test_cmp_jy.py b/Lib/test/test_cmp_jy.py --- a/Lib/test/test_cmp_jy.py +++ b/Lib/test/test_cmp_jy.py @@ -45,6 +45,46 @@ assert not (-1 == 'a') +class ObjectCmp(unittest.TestCase): + def testObjectListCompares(self): + # Also applies to tuple objects given common PySequence implementation + assert not object() == list() + assert object() != list() + assert not list() == object() + assert list() != object() + + # Note that <, > rich comparisons in 2.x are broken by the + # lexicographic ordering of the type **name**. 
Example: + # 'object' > 'list' + assert not object() < list() + assert not object() <= list() + assert object() > list() + assert object() >= list() + assert list() < object() + assert list() <= object() + assert not list() > object() + assert not list() >= object() + + def testObjectDictCompares(self): + # Also applies to such objects as defaultdict and Counter + assert not object() == dict() + assert object() != dict() + assert not dict() == object() + assert dict() != object() + + # Note that <, > rich comparisons in 2.x are broken by the + # lexicographic ordering of the type **name**. Example: + # 'object' > 'dict' + assert not object() < dict() + assert not object() <= dict() + assert object() > dict() + assert object() >= dict() + assert dict() < object() + assert dict() <= object() + assert not dict() > object() + assert not dict() >= object() + + class CustomCmp(unittest.TestCase): def test___cmp___returns(self): class Foo(object): @@ -83,7 +123,8 @@ UnicodeDerivedCmp, LongDerivedCmp, IntStrCmp, - CustomCmp + ObjectCmp, + CustomCmp, ) diff --git a/src/org/python/core/PyDictionary.java b/src/org/python/core/PyDictionary.java --- a/src/org/python/core/PyDictionary.java +++ b/src/org/python/core/PyDictionary.java @@ -259,7 +259,7 @@ PyType thisType = getType(); PyType otherType = otherObj.getType(); if (otherType != thisType && !thisType.isSubType(otherType) - && !otherType.isSubType(thisType)) { + && !otherType.isSubType(thisType) || otherType == PyObject.TYPE) { return null; } PyDictionary other = (PyDictionary)otherObj; @@ -344,7 +344,7 @@ PyType thisType = getType(); PyType otherType = otherObj.getType(); if (otherType != thisType && !thisType.isSubType(otherType) - && !otherType.isSubType(thisType)) { + && !otherType.isSubType(thisType) || otherType == PyObject.TYPE) { return -2; } PyDictionary other = (PyDictionary)otherObj; diff --git a/src/org/python/core/PySequence.java b/src/org/python/core/PySequence.java --- a/src/org/python/core/PySequence.java 
+++ b/src/org/python/core/PySequence.java @@ -156,7 +156,7 @@ } final PyObject seq___eq__(PyObject o) { - if (!isSubType(o)) { + if (!isSubType(o) || o.getType() == PyObject.TYPE) { return null; } int tl = __len__(); @@ -174,7 +174,7 @@ } final PyObject seq___ne__(PyObject o) { - if (!isSubType(o)) { + if (!isSubType(o) || o.getType() == PyObject.TYPE) { return null; } int tl = __len__(); @@ -192,7 +192,7 @@ } final PyObject seq___lt__(PyObject o) { - if (!isSubType(o)) { + if (!isSubType(o) || o.getType() == PyObject.TYPE) { return null; } int i = cmp(this, -1, o, -1); @@ -208,7 +208,7 @@ } final PyObject seq___le__(PyObject o) { - if (!isSubType(o)) { + if (!isSubType(o) || o.getType() == PyObject.TYPE) { return null; } int i = cmp(this, -1, o, -1); @@ -224,7 +224,7 @@ } final PyObject seq___gt__(PyObject o) { - if (!isSubType(o)) { + if (!isSubType(o) || o.getType() == PyObject.TYPE) { return null; } int i = cmp(this, -1, o, -1); @@ -240,7 +240,7 @@ } final PyObject seq___ge__(PyObject o) { - if (!isSubType(o)) { + if (!isSubType(o) || o.getType() == PyObject.TYPE) { return null; } int i = cmp(this, -1, o, -1); -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 08:03:42 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 07:03:42 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Systematic_sweep_of_flaky_t?= =?utf-8?q?ests_in_test=5Fweakref?= Message-ID: <20141209070339.678.31103@psf.io> https://hg.python.org/jython/rev/ebb6d9049d15 changeset: 7443:ebb6d9049d15 user: Jim Baker date: Tue Dec 09 00:03:35 2014 -0700 summary: Systematic sweep of flaky tests in test_weakref Attempt to identify all use of weakref callbacks to ensure extra_collect() is called to ensure reaper thread can do its work. 
files: Lib/test/test_weakref.py | 10 +++++----- 1 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -638,7 +638,7 @@ del c1, c2, C # make them all trash self.assertEqual(alist, []) # del isn't enough to reclaim anything - gc.collect() + extra_collect() # c1.wr and c2.wr were part of the cyclic trash, so should have # been cleared without their callbacks executing. OTOH, the weakref # to C is bound to a function local (wr), and wasn't trash, so that @@ -682,7 +682,7 @@ del callback, c, d, C self.assertEqual(alist, []) # del isn't enough to clean up cycles - gc.collect() + extra_collect() self.assertEqual(alist, ["safe_callback called"]) self.assertEqual(external_wr(), None) @@ -755,12 +755,12 @@ weakref.ref(int) a = weakref.ref(A, l.append) A = None - gc.collect() + extra_collect() self.assertEqual(a(), None) self.assertEqual(l, [a]) b = weakref.ref(B, l.append) B = None - gc.collect() + extra_collect() self.assertEqual(b(), None) self.assertEqual(l, [a, b]) @@ -850,7 +850,7 @@ self.assertTrue(mr.called) self.assertEqual(mr.value, 24) del o - gc.collect() + extra_collect() self.assertTrue(mr() is None) self.assertTrue(mr.called) -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 9 23:34:25 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 09 Dec 2014 22:34:25 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Update_pkgutil=2Ewalk=5Fpac?= =?utf-8?q?kages_so_it_ignores_unreadable_dirs?= Message-ID: <20141209223421.118672.74559@psf.io> https://hg.python.org/jython/rev/1823003917fd changeset: 7444:1823003917fd user: Jim Baker date: Tue Dec 09 15:24:37 2014 -0700 summary: Update pkgutil.walk_packages so it ignores unreadable dirs files: Lib/pkgutil.py | 7 +- Lib/test/test_pkgutil.py | 142 --------------------------- 2 files changed, 6 insertions(+), 143 deletions(-) diff --git 
a/Lib/pkgutil.py b/Lib/pkgutil.py --- a/Lib/pkgutil.py +++ b/Lib/pkgutil.py @@ -214,7 +214,12 @@ if not modname and os.path.isdir(path) and '.' not in fn: modname = fn - for fn in os.listdir(path): + try: + dircontents = os.listdir(path) + except OSError: + # ignore unreadable directories like import does + dircontents = [] + for fn in dircontents: subname = inspect.getmodulename(fn) if subname=='__init__': ispkg = True diff --git a/Lib/test/test_pkgutil.py b/Lib/test/test_pkgutil.py deleted file mode 100644 --- a/Lib/test/test_pkgutil.py +++ /dev/null @@ -1,142 +0,0 @@ -from test.test_support import run_unittest, is_jython -import unittest -import sys -import imp -import pkgutil -import os -import os.path -import tempfile -import shutil -import zipfile - - - -class PkgutilTests(unittest.TestCase): - - def setUp(self): - self.dirname = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, self.dirname) - sys.path.insert(0, self.dirname) - - def tearDown(self): - del sys.path[0] - - def test_getdata_filesys(self): - pkg = 'test_getdata_filesys' - - # Include a LF and a CRLF, to test that binary data is read back - RESOURCE_DATA = 'Hello, world!\nSecond line\r\nThird line' - - # Make a package with some resources - package_dir = os.path.join(self.dirname, pkg) - os.mkdir(package_dir) - # Empty init.py - f = open(os.path.join(package_dir, '__init__.py'), "wb") - f.close() - # Resource files, res.txt, sub/res.txt - f = open(os.path.join(package_dir, 'res.txt'), "wb") - f.write(RESOURCE_DATA) - f.close() - os.mkdir(os.path.join(package_dir, 'sub')) - f = open(os.path.join(package_dir, 'sub', 'res.txt'), "wb") - f.write(RESOURCE_DATA) - f.close() - - # Check we can read the resources - res1 = pkgutil.get_data(pkg, 'res.txt') - self.assertEqual(res1, RESOURCE_DATA) - res2 = pkgutil.get_data(pkg, 'sub/res.txt') - self.assertEqual(res2, RESOURCE_DATA) - - del sys.modules[pkg] - - def test_getdata_zipfile(self): - zip = 'test_getdata_zipfile.zip' - pkg = 'test_getdata_zipfile' 
- - # Include a LF and a CRLF, to test that binary data is read back - RESOURCE_DATA = 'Hello, world!\nSecond line\r\nThird line' - - # Make a package with some resources - zip_file = os.path.join(self.dirname, zip) - z = zipfile.ZipFile(zip_file, 'w') - - # Empty init.py - z.writestr(pkg + '/__init__.py', "") - # Resource files, res.txt, sub/res.txt - z.writestr(pkg + '/res.txt', RESOURCE_DATA) - z.writestr(pkg + '/sub/res.txt', RESOURCE_DATA) - z.close() - - # Check we can read the resources - sys.path.insert(0, zip_file) - res1 = pkgutil.get_data(pkg, 'res.txt') - self.assertEqual(res1, RESOURCE_DATA) - res2 = pkgutil.get_data(pkg, 'sub/res.txt') - self.assertEqual(res2, RESOURCE_DATA) - del sys.path[0] - - del sys.modules[pkg] - - @unittest.skipIf(is_jython, "FIXME: not working on Jython") - def test_unreadable_dir_on_syspath(self): - # issue7367 - walk_packages failed if unreadable dir on sys.path - package_name = "unreadable_package" - d = os.path.join(self.dirname, package_name) - # this does not appear to create an unreadable dir on Windows - # but the test should not fail anyway - os.mkdir(d, 0) - self.addCleanup(os.rmdir, d) - for t in pkgutil.walk_packages(path=[self.dirname]): - self.fail("unexpected package found") - -class PkgutilPEP302Tests(unittest.TestCase): - - class MyTestLoader(object): - def load_module(self, fullname): - # Create an empty module - mod = sys.modules.setdefault(fullname, imp.new_module(fullname)) - mod.__file__ = "<%s>" % self.__class__.__name__ - mod.__loader__ = self - # Make it a package - mod.__path__ = [] - # Count how many times the module is reloaded - mod.__dict__['loads'] = mod.__dict__.get('loads',0) + 1 - return mod - - def get_data(self, path): - return "Hello, world!" 
- - class MyTestImporter(object): - def find_module(self, fullname, path=None): - return PkgutilPEP302Tests.MyTestLoader() - - def setUp(self): - sys.meta_path.insert(0, self.MyTestImporter()) - - def tearDown(self): - del sys.meta_path[0] - - def test_getdata_pep302(self): - # Use a dummy importer/loader - self.assertEqual(pkgutil.get_data('foo', 'dummy'), "Hello, world!") - del sys.modules['foo'] - - def test_alreadyloaded(self): - # Ensure that get_data works without reloading - the "loads" module - # variable in the example loader should count how many times a reload - # occurs. - import foo - self.assertEqual(foo.loads, 1) - self.assertEqual(pkgutil.get_data('foo', 'dummy'), "Hello, world!") - self.assertEqual(foo.loads, 1) - del sys.modules['foo'] - -def test_main(): - run_unittest(PkgutilTests, PkgutilPEP302Tests) - # this is necessary if test is run repeated (like when finding leaks) - import zipimport - zipimport._zip_directory_cache.clear() - -if __name__ == '__main__': - test_main() -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 13 02:58:43 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 13 Dec 2014 01:58:43 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Imported_modules_now_have_?= =?utf-8?b?X19idWlsdGluc19fIHNldC4=?= Message-ID: <20141213015837.81001.32928@psf.io> https://hg.python.org/jython/rev/c33128fdf05f changeset: 7445:c33128fdf05f user: Jim Baker date: Fri Dec 12 18:58:32 2014 -0700 summary: Imported modules now have __builtins__ set. In the past Jython did not support __builtins__ as an implementation detail of CPython. See for example https://docs.python.org/2/library/__builtin__.html and http://bugs.jython.org/issue1890. However, we might as well support it, since doing so further minimizes compatibility differences. Also updated test_builtin and test_module to latest version from 2.7, then applied a minimal set of Jython-specific changes. 
Removed Jython-specific test_pdb files: Lib/test/test_builtin.py | 316 ++++++++++---- Lib/test/test_codeop_jy.py | 1 + Lib/test/test_descr.py | 2 +- Lib/test/test_module.py | 142 +++--- Lib/test/test_pdb.py | 316 --------------- Lib/test/test_support.py | 16 + src/org/python/core/Py.java | 13 + src/org/python/core/PyModule.java | 24 +- src/org/python/core/PyType.java | 4 +- src/org/python/core/__builtin__.java | 70 +-- src/org/python/core/imp.java | 4 + 11 files changed, 371 insertions(+), 537 deletions(-) diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -2,13 +2,13 @@ import platform import unittest +from test.test_support import fcmp, have_unicode, TESTFN, unlink, \ + run_unittest, check_py3k_warnings, is_jython import warnings -from test.test_support import (fcmp, have_unicode, TESTFN, unlink, - run_unittest, check_py3k_warnings, check_warnings, - is_jython) from operator import neg import sys, cStringIO, random, UserDict + # count the number of test runs. 
# used to skip running test_execfile() multiple times # and to create unique strings to intern in test_intern() @@ -90,6 +90,16 @@ self.assertEqual(abs(-1234L), 1234L) # str self.assertRaises(TypeError, abs, 'a') + # bool + self.assertEqual(abs(True), 1) + self.assertEqual(abs(False), 0) + # other + self.assertRaises(TypeError, abs) + self.assertRaises(TypeError, abs, None) + class AbsClass(object): + def __abs__(self): + return -5 + self.assertEqual(abs(AbsClass()), -5) def test_all(self): self.assertEqual(all([2, 4, 6]), True) @@ -100,6 +110,7 @@ self.assertRaises(TypeError, all) # No args self.assertRaises(TypeError, all, [2, 4, 6], []) # Too many args self.assertEqual(all([]), True) # Empty iterator + self.assertEqual(all([0, TestFailingBool()]), False)# Short-circuit S = [50, 60] self.assertEqual(all(x > 42 for x in S), True) S = [50, 40, 60] @@ -109,11 +120,12 @@ self.assertEqual(any([None, None, None]), False) self.assertEqual(any([None, 4, None]), True) self.assertRaises(RuntimeError, any, [None, TestFailingBool(), 6]) - self.assertRaises(RuntimeError, all, TestFailingIter()) + self.assertRaises(RuntimeError, any, TestFailingIter()) self.assertRaises(TypeError, any, 10) # Non-iterable self.assertRaises(TypeError, any) # No args self.assertRaises(TypeError, any, [2, 4, 6], []) # Too many args self.assertEqual(any([]), False) # Empty iterator + self.assertEqual(any([1, TestFailingBool()]), True) # Short-circuit S = [40, 60, 30] self.assertEqual(any(x > 42 for x in S), True) S = [10, 20, 30] @@ -121,7 +133,7 @@ def test_neg(self): x = -sys.maxint-1 - self.assert_(isinstance(x, int)) + self.assertTrue(isinstance(x, int)) self.assertEqual(-x, sys.maxint+1) def test_apply(self): @@ -151,20 +163,45 @@ self.assertRaises(TypeError, apply, id, (42,), 42) def test_callable(self): - self.assert_(callable(len)) + self.assertTrue(callable(len)) + self.assertFalse(callable("a")) + self.assertTrue(callable(callable)) + self.assertTrue(callable(lambda x, y: x + y)) + 
self.assertFalse(callable(__builtins__)) def f(): pass - self.assert_(callable(f)) - class C: + self.assertTrue(callable(f)) + + class Classic: def meth(self): pass - self.assert_(callable(C)) - x = C() - self.assert_(callable(x.meth)) - self.assert_(not callable(x)) - class D(C): + self.assertTrue(callable(Classic)) + c = Classic() + self.assertTrue(callable(c.meth)) + self.assertFalse(callable(c)) + + class NewStyle(object): + def meth(self): pass + self.assertTrue(callable(NewStyle)) + n = NewStyle() + self.assertTrue(callable(n.meth)) + self.assertFalse(callable(n)) + + # Classic and new-style classes evaluate __call__() differently + c.__call__ = None + self.assertTrue(callable(c)) + del c.__call__ + self.assertFalse(callable(c)) + n.__call__ = None + self.assertFalse(callable(n)) + del n.__call__ + self.assertFalse(callable(n)) + + class N2(object): def __call__(self): pass - y = D() - self.assert_(callable(y)) - y() + n2 = N2() + self.assertTrue(callable(n2)) + class N3(N2): pass + n3 = N3() + self.assertTrue(callable(n3)) def test_chr(self): self.assertEqual(chr(32), ' ') @@ -178,23 +215,29 @@ self.assertEqual(cmp(-1, 1), -1) self.assertEqual(cmp(1, -1), 1) self.assertEqual(cmp(1, 1), 0) - # verify that circular objects are handled for Jython + # verify that circular objects are not handled a = []; a.append(a) b = []; b.append(b) from UserList import UserList c = UserList(); c.append(c) - self.assertEqual(cmp(a, b), 0) - self.assertEqual(cmp(b, c), 0) - self.assertEqual(cmp(c, a), 0) - self.assertEqual(cmp(a, c), 0) - # okay, now break the cycles + if is_jython: + self.assertEqual(cmp(a, b), 0) + self.assertEqual(cmp(b, c), 0) + self.assertEqual(cmp(c, a), 0) + self.assertEqual(cmp(a, c), 0) + else: + self.assertRaises(RuntimeError, cmp, a, b) + self.assertRaises(RuntimeError, cmp, b, c) + self.assertRaises(RuntimeError, cmp, c, a) + self.assertRaises(RuntimeError, cmp, a, c) + # okay, now break the cycles a.pop(); b.pop(); c.pop() 
self.assertRaises(TypeError, cmp) def test_coerce(self): - self.assert_(not fcmp(coerce(1, 1.1), (1.0, 1.1))) + self.assertTrue(not fcmp(coerce(1, 1.1), (1.0, 1.1))) self.assertEqual(coerce(1, 1L), (1L, 1L)) - self.assert_(not fcmp(coerce(1L, 1.1), (1.0, 1.1))) + self.assertTrue(not fcmp(coerce(1L, 1.1), (1.0, 1.1))) self.assertRaises(TypeError, coerce) class BadNumber: def __coerce__(self, other): @@ -233,23 +276,22 @@ # dir() - local scope local_var = 1 - self.assert_('local_var' in dir()) + self.assertIn('local_var', dir()) # dir(module) import sys - self.assert_('exit' in dir(sys)) + self.assertIn('exit', dir(sys)) # dir(module_with_invalid__dict__) import types class Foo(types.ModuleType): __dict__ = 8 f = Foo("foo") - if not is_jython: #FIXME #1861 - self.assertRaises(TypeError, dir, f) + self.assertRaises(TypeError, dir, f) # dir(type) - self.assert_("strip" in dir(str)) - self.assert_("__mro__" not in dir(str)) + self.assertIn("strip", dir(str)) + self.assertNotIn("__mro__", dir(str)) # dir(obj) class Foo(object): @@ -258,13 +300,13 @@ self.y = 8 self.z = 9 f = Foo() - self.assert_("y" in dir(f)) + self.assertIn("y", dir(f)) # dir(obj_no__dict__) class Foo(object): __slots__ = [] f = Foo() - self.assert_("__repr__" in dir(f)) + self.assertIn("__repr__", dir(f)) # dir(obj_no__class__with__dict__) # (an ugly trick to cause getattr(f, "__class__") to fail) @@ -273,24 +315,22 @@ def __init__(self): self.bar = "wow" f = Foo() - self.assert_("__repr__" not in dir(f)) - self.assert_("bar" in dir(f)) + self.assertNotIn("__repr__", dir(f)) + self.assertIn("bar", dir(f)) # dir(obj_using __dir__) class Foo(object): def __dir__(self): return ["kan", "ga", "roo"] f = Foo() - if not is_jython: #FIXME #1861 - self.assert_(dir(f) == ["ga", "kan", "roo"]) + self.assertTrue(dir(f) == ["ga", "kan", "roo"]) # dir(obj__dir__not_list) class Foo(object): def __dir__(self): return 7 f = Foo() - if not is_jython: #FIXME #1861 - self.assertRaises(TypeError, dir, f) + 
self.assertRaises(TypeError, dir, f) def test_divmod(self): self.assertEqual(divmod(12, 7), (1, 5)) @@ -311,10 +351,10 @@ self.assertEqual(divmod(-sys.maxint-1, -1), (sys.maxint+1, 0)) - self.assert_(not fcmp(divmod(3.25, 1.0), (3.0, 0.25))) - self.assert_(not fcmp(divmod(-3.25, 1.0), (-4.0, 0.75))) - self.assert_(not fcmp(divmod(3.25, -1.0), (-4.0, -0.75))) - self.assert_(not fcmp(divmod(-3.25, -1.0), (3.0, -0.25))) + self.assertTrue(not fcmp(divmod(3.25, 1.0), (3.0, 0.25))) + self.assertTrue(not fcmp(divmod(-3.25, 1.0), (-4.0, 0.75))) + self.assertTrue(not fcmp(divmod(3.25, -1.0), (-4.0, -0.75))) + self.assertTrue(not fcmp(divmod(-3.25, -1.0), (3.0, -0.25))) self.assertRaises(TypeError, divmod) @@ -363,9 +403,12 @@ self.assertEqual(eval('dir()', g, m), list('xyz')) self.assertEqual(eval('globals()', g, m), g) self.assertEqual(eval('locals()', g, m), m) - - # Jython allows arbitrary mappings for globals - self.assertEqual(eval('a', m), 12) + if is_jython: + # Jython allows any mapping to work, including ones that + # are read only as in the case of M + self.assertEqual(eval('a', m), 12) + else: + self.assertRaises(TypeError, eval, 'a', m) class A: "Non-mapping" pass @@ -577,11 +620,11 @@ for func in funcs: outp = filter(func, cls(inp)) self.assertEqual(outp, exp) - self.assert_(not isinstance(outp, cls)) + self.assertTrue(not isinstance(outp, cls)) def test_getattr(self): import sys - self.assert_(getattr(sys, 'stdout') is sys.stdout) + self.assertTrue(getattr(sys, 'stdout') is sys.stdout) self.assertRaises(TypeError, getattr, sys, 1) self.assertRaises(TypeError, getattr, sys, 1, "foo") self.assertRaises(TypeError, getattr) @@ -590,7 +633,7 @@ def test_hasattr(self): import sys - self.assert_(hasattr(sys, 'stdout')) + self.assertTrue(hasattr(sys, 'stdout')) self.assertRaises(TypeError, hasattr, sys, 1) self.assertRaises(TypeError, hasattr) if have_unicode: @@ -621,15 +664,15 @@ class X: def __hash__(self): return 2**100 - self.assertEquals(type(hash(X())), int) + 
self.assertEqual(type(hash(X())), int) class Y(object): def __hash__(self): return 2**100 - self.assertEquals(type(hash(Y())), int) + self.assertEqual(type(hash(Y())), int) class Z(long): def __hash__(self): return self - self.assertEquals(hash(Z(42)), hash(42L)) + self.assertEqual(hash(Z(42)), hash(42L)) def test_hex(self): self.assertEqual(hex(16), '0x10') @@ -650,20 +693,22 @@ # Test input() later, together with raw_input + # test_int(): see test_int.py for int() tests. + def test_intern(self): self.assertRaises(TypeError, intern) # This fails if the test is run twice with a constant string, # therefore append the run counter s = "never interned before " + str(numruns) - self.assert_(intern(s) is s) + self.assertTrue(intern(s) is s) s2 = s.swapcase().swapcase() - self.assert_(intern(s2) is s) + self.assertTrue(intern(s2) is s) # Subclasses of string can't be interned, because they # provide too much opportunity for insane things to happen. # We don't want them in the interned dict and if they aren't # actually interned, we don't want to create the appearance - # that they are by allowing intern() to succeeed. + # that they are by allowing intern() to succeed. 
class S(str): def __hash__(self): return 123 @@ -698,11 +743,11 @@ c = C() d = D() e = E() - self.assert_(isinstance(c, C)) - self.assert_(isinstance(d, C)) - self.assert_(not isinstance(e, C)) - self.assert_(not isinstance(c, D)) - self.assert_(not isinstance('foo', E)) + self.assertTrue(isinstance(c, C)) + self.assertTrue(isinstance(d, C)) + self.assertTrue(not isinstance(e, C)) + self.assertTrue(not isinstance(c, D)) + self.assertTrue(not isinstance('foo', E)) self.assertRaises(TypeError, isinstance, E, 'foo') self.assertRaises(TypeError, isinstance) @@ -716,9 +761,9 @@ c = C() d = D() e = E() - self.assert_(issubclass(D, C)) - self.assert_(issubclass(C, C)) - self.assert_(not issubclass(C, D)) + self.assertTrue(issubclass(D, C)) + self.assertTrue(issubclass(C, C)) + self.assertTrue(not issubclass(C, D)) self.assertRaises(TypeError, issubclass, 'foo', E) self.assertRaises(TypeError, issubclass, E, 'foo') self.assertRaises(TypeError, issubclass) @@ -734,6 +779,11 @@ def __len__(self): raise ValueError self.assertRaises(ValueError, len, BadSeq()) + self.assertRaises(TypeError, len, 2) + class ClassicStyle: pass + class NewStyle(object): pass + self.assertRaises(AttributeError, len, ClassicStyle()) + self.assertRaises(TypeError, len, NewStyle()) def test_map(self): self.assertEqual( @@ -895,7 +945,7 @@ self.assertEqual(next(it), 1) self.assertRaises(StopIteration, next, it) self.assertRaises(StopIteration, next, it) - self.assertEquals(next(it, 42), 42) + self.assertEqual(next(it, 42), 42) class Iter(object): def __iter__(self): @@ -904,7 +954,7 @@ raise StopIteration it = iter(Iter()) - self.assertEquals(next(it, 42), 42) + self.assertEqual(next(it, 42), 42) self.assertRaises(StopIteration, next, it) def gen(): @@ -912,9 +962,9 @@ return it = gen() - self.assertEquals(next(it), 1) + self.assertEqual(next(it), 1) self.assertRaises(StopIteration, next, it) - self.assertEquals(next(it, 42), 42) + self.assertEqual(next(it, 42), 42) def test_oct(self): 
self.assertEqual(oct(100), '0144') @@ -1050,18 +1100,18 @@ self.assertEqual(range(a+4, a, -2), [a+4, a+2]) seq = range(a, b, c) - self.assert_(a in seq) - self.assert_(b not in seq) + self.assertIn(a, seq) + self.assertNotIn(b, seq) self.assertEqual(len(seq), 2) seq = range(b, a, -c) - self.assert_(b in seq) - self.assert_(a not in seq) + self.assertIn(b, seq) + self.assertNotIn(a, seq) self.assertEqual(len(seq), 2) seq = range(-a, -b, -c) - self.assert_(-a in seq) - self.assert_(-b not in seq) + self.assertIn(-a, seq) + self.assertNotIn(-b, seq) self.assertEqual(len(seq), 2) self.assertRaises(TypeError, range) @@ -1075,14 +1125,9 @@ __hash__ = None # Invalid cmp makes this unhashable self.assertRaises(RuntimeError, range, a, a + 1, badzero(1)) - # Reject floats when it would require PyLongs to represent. - # (smaller floats still accepted, but deprecated) - with check_warnings() as w: - warnings.simplefilter("always") - self.assertRaises(TypeError, range, 1e100, 1e101, 1e101) - with check_warnings() as w: - warnings.simplefilter("always") - self.assertEqual(range(1.0), [0]) + # Reject floats. + self.assertRaises(TypeError, range, 1., 1., 1.) 
+ self.assertRaises(TypeError, range, 1e100, 1e101, 1e101) self.assertRaises(TypeError, range, 0, "spam") self.assertRaises(TypeError, range, 0, 42, "spam") @@ -1124,20 +1169,21 @@ # Exercise various combinations of bad arguments, to check # refcounting logic - with check_warnings(): - self.assertRaises(TypeError, range, 1e100) + self.assertRaises(TypeError, range, 0.0) - self.assertRaises(TypeError, range, 0, 1e100) - self.assertRaises(TypeError, range, 1e100, 0) - self.assertRaises(TypeError, range, 1e100, 1e100) + self.assertRaises(TypeError, range, 0, 0.0) + self.assertRaises(TypeError, range, 0.0, 0) + self.assertRaises(TypeError, range, 0.0, 0.0) - self.assertRaises(TypeError, range, 0, 0, 1e100) - self.assertRaises(TypeError, range, 0, 1e100, 1) - self.assertRaises(TypeError, range, 0, 1e100, 1e100) - self.assertRaises(TypeError, range, 1e100, 0, 1) - self.assertRaises(TypeError, range, 1e100, 0, 1e100) - self.assertRaises(TypeError, range, 1e100, 1e100, 1) - self.assertRaises(TypeError, range, 1e100, 1e100, 1e100) + self.assertRaises(TypeError, range, 0, 0, 1.0) + self.assertRaises(TypeError, range, 0, 0.0, 1) + self.assertRaises(TypeError, range, 0, 0.0, 1.0) + self.assertRaises(TypeError, range, 0.0, 0, 1) + self.assertRaises(TypeError, range, 0.0, 0, 1.0) + self.assertRaises(TypeError, range, 0.0, 0.0, 1) + self.assertRaises(TypeError, range, 0.0, 0.0, 1.0) + + def test_input_and_raw_input(self): self.write_testfile() @@ -1197,9 +1243,10 @@ unlink(TESTFN) def test_reduce(self): - self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc') + add = lambda x, y: x+y + self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc') self.assertEqual( - reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []), + reduce(add, [['a', 'c'], [], ['d', 'w']], []), ['a','c','d','w'] ) self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040) @@ -1207,15 +1254,23 @@ reduce(lambda x, y: x*y, range(2,21), 1L), 2432902008176640000L ) - 
self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285) - self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285) - self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0) + self.assertEqual(reduce(add, Squares(10)), 285) + self.assertEqual(reduce(add, Squares(10), 0), 285) + self.assertEqual(reduce(add, Squares(0), 0), 0) self.assertRaises(TypeError, reduce) + self.assertRaises(TypeError, reduce, 42) self.assertRaises(TypeError, reduce, 42, 42) self.assertRaises(TypeError, reduce, 42, 42, 42) + self.assertRaises(TypeError, reduce, None, range(5)) + self.assertRaises(TypeError, reduce, add, 42) self.assertEqual(reduce(42, "1"), "1") # func is never called with one item self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item self.assertRaises(TypeError, reduce, 42, (42, 42)) + self.assertRaises(TypeError, reduce, add, []) # arg 2 must not be empty sequence with no initial value + self.assertRaises(TypeError, reduce, add, "") + self.assertRaises(TypeError, reduce, add, ()) + self.assertEqual(reduce(add, [], None), None) + self.assertEqual(reduce(add, [], 42), 42) class BadSeq: def __getitem__(self, index): @@ -1318,6 +1373,19 @@ self.assertRaises(TypeError, round, t) self.assertRaises(TypeError, round, t, 0) + # Some versions of glibc for alpha have a bug that affects + # float -> integer rounding (floor, ceil, rint, round) for + # values in the range [2**52, 2**53). See: + # + # http://sources.redhat.com/bugzilla/show_bug.cgi?id=5350 + # + # We skip this test on Linux/alpha if it would fail. 
+ linux_alpha = (platform.system().startswith('Linux') and + platform.machine().startswith('alpha')) + system_round_bug = round(5e15+1) != 5e15+1 + @unittest.skipIf(linux_alpha and system_round_bug, + "test will fail; failure is probably due to a " + "buggy system round function") def test_round_large(self): # Issue #1869: integral floats should remain unchanged self.assertEqual(round(5e15-1), 5e15-1) @@ -1353,6 +1421,10 @@ raise ValueError self.assertRaises(ValueError, sum, BadSeq()) + empty = [] + sum(([x] for x in range(10)), empty) + self.assertEqual(empty, []) + def test_type(self): self.assertEqual(type(''), type('123')) self.assertNotEqual(type(''), type(())) @@ -1368,8 +1440,7 @@ ) self.assertRaises(ValueError, unichr, sys.maxunicode+1) self.assertRaises(TypeError, unichr) - if not is_jython: #FIXME #1861 - self.assertRaises((OverflowError, ValueError), unichr, 2**32) + self.assertRaises((OverflowError, ValueError), unichr, 2**32) # We don't want self in vars(), so these are static methods @@ -1384,6 +1455,11 @@ b = 2 return vars() + class C_get_vars(object): + def getDict(self): + return {'a':2} + __dict__ = property(fget=getDict) + def test_vars(self): self.assertEqual(set(vars()), set(dir())) import sys @@ -1392,6 +1468,7 @@ self.assertEqual(self.get_vars_f2(), {'a': 1, 'b': 2}) self.assertRaises(TypeError, vars, 42, 42) self.assertRaises(TypeError, vars, 42) + self.assertEqual(vars(self.C_get_vars()), {'a':2}) def test_zip(self): a = (1, 2, 3) @@ -1511,8 +1588,7 @@ class BadFormatResult: def __format__(self, format_spec): return 1.0 - if not is_jython: #FIXME #1861 check again when __format__ works better. 
- self.assertRaises(TypeError, format, BadFormatResult(), "") + self.assertRaises(TypeError, format, BadFormatResult(), "") # TypeError because format_spec is not unicode or str self.assertRaises(TypeError, format, object(), 4) @@ -1521,13 +1597,48 @@ # tests for object.__format__ really belong elsewhere, but # there's no good place to put them x = object().__format__('') - self.assert_(x.startswith('>> def test_function(foo, bar): - ... import pdb; pdb.Pdb().set_trace() - ... pass - - >>> with PdbTestInput([ - ... 'foo', - ... 'bar', - ... 'for i in range(5): write(i)', - ... 'continue', - ... ]): - ... test_function(1, None) - > (3)test_function() - -> pass - (Pdb) foo - 1 - (Pdb) bar - (Pdb) for i in range(5): write(i) - 0 - 1 - 2 - 3 - 4 - (Pdb) continue - """ - -def test_pdb_breakpoint_commands(): - """Test basic commands related to breakpoints. - - >>> def test_function(): - ... import pdb; pdb.Pdb().set_trace() - ... print(1) - ... print(2) - ... print(3) - ... print(4) - - First, need to clear bdb state that might be left over from previous tests. - Otherwise, the new breakpoints might get assigned different numbers. - - >>> from bdb import Breakpoint - >>> Breakpoint.next = 1 - >>> Breakpoint.bplist = {} - >>> Breakpoint.bpbynumber = [None] - - Now test the breakpoint commands. NORMALIZE_WHITESPACE is needed because - the breakpoint list outputs a tab for the "stop only" and "ignore next" - lines, which we don't want to put in here. - - >>> with PdbTestInput([ # doctest: +NORMALIZE_WHITESPACE - ... 'break 3', - ... 'disable 1', - ... 'ignore 1 10', - ... 'condition 1 1 < 2', - ... 'break 4', - ... 'break 4', - ... 'break', - ... 'clear 3', - ... 'break', - ... 'condition 1', - ... 'enable 1', - ... 'clear 1', - ... 'commands 2', - ... 'print 42', - ... 'end', - ... 'continue', # will stop at breakpoint 2 (line 4) - ... 'clear', # clear all! - ... 'y', - ... 'tbreak 5', - ... 'continue', # will stop at temporary breakpoint - ... 
'break', # make sure breakpoint is gone - ... 'continue', - ... ]): - ... test_function() - > (3)test_function() - -> print(1) - (Pdb) break 3 - Breakpoint 1 at :3 - (Pdb) disable 1 - (Pdb) ignore 1 10 - Will ignore next 10 crossings of breakpoint 1. - (Pdb) condition 1 1 < 2 - (Pdb) break 4 - Breakpoint 2 at :4 - (Pdb) break 4 - Breakpoint 3 at :4 - (Pdb) break - Num Type Disp Enb Where - 1 breakpoint keep no at :3 - stop only if 1 < 2 - ignore next 10 hits - 2 breakpoint keep yes at :4 - 3 breakpoint keep yes at :4 - (Pdb) clear 3 - Deleted breakpoint 3 - (Pdb) break - Num Type Disp Enb Where - 1 breakpoint keep no at :3 - stop only if 1 < 2 - ignore next 10 hits - 2 breakpoint keep yes at :4 - (Pdb) condition 1 - Breakpoint 1 is now unconditional. - (Pdb) enable 1 - (Pdb) clear 1 - Deleted breakpoint 1 - (Pdb) commands 2 - (com) print 42 - (com) end - (Pdb) continue - 1 - 42 - > (4)test_function() - -> print(2) - (Pdb) clear - Clear all breaks? y - (Pdb) tbreak 5 - Breakpoint 4 at :5 - (Pdb) continue - 2 - Deleted breakpoint 4 - > (5)test_function() - -> print(3) - (Pdb) break - (Pdb) continue - 3 - 4 - """ - - -def test_pdb_skip_modules(): - """This illustrates the simple case of module skipping. - - >>> def skip_module(): - ... import string - ... import pdb; pdb.Pdb(skip=['string*']).set_trace() - ... string.lower('FOO') - - >>> with PdbTestInput([ - ... 'step', - ... 'continue', - ... ]): - ... skip_module() - > (4)skip_module() - -> string.lower('FOO') - (Pdb) step - --Return-- - > (4)skip_module()->None - -> string.lower('FOO') - (Pdb) continue - """ - - -# Module for testing skipping of module that makes a callback -mod = imp.new_module('module_to_skip') -exec 'def foo_pony(callback): x = 1; callback(); return None' in mod.__dict__ - - -def test_pdb_skip_modules_with_callback(): - """This illustrates skipping of modules that call into other code. - - >>> def skip_module(): - ... def callback(): - ... return None - ... 
import pdb; pdb.Pdb(skip=['module_to_skip*']).set_trace() - ... mod.foo_pony(callback) - - >>> with PdbTestInput([ - ... 'step', - ... 'step', - ... 'step', - ... 'step', - ... 'step', - ... 'continue', - ... ]): - ... skip_module() - ... pass # provides something to "step" to - > (5)skip_module() - -> mod.foo_pony(callback) - (Pdb) step - --Call-- - > (2)callback() - -> def callback(): - (Pdb) step - > (3)callback() - -> return None - (Pdb) step - --Return-- - > (3)callback()->None - -> return None - (Pdb) step - --Return-- - > (5)skip_module()->None - -> mod.foo_pony(callback) - (Pdb) step - > (10)() - -> pass # provides something to "step" to - (Pdb) continue - """ - - -def test_pdb_continue_in_bottomframe(): - """Test that "continue" and "next" work properly in bottom frame (issue #5294). - - >>> def test_function(): - ... import pdb, sys; inst = pdb.Pdb() - ... inst.set_trace() - ... inst.botframe = sys._getframe() # hackery to get the right botframe - ... print(1) - ... print(2) - ... print(3) - ... print(4) - - First, need to clear bdb state that might be left over from previous tests. - Otherwise, the new breakpoints might get assigned different numbers. - - >>> from bdb import Breakpoint - >>> Breakpoint.next = 1 - >>> Breakpoint.bplist = {} - >>> Breakpoint.bpbynumber = [None] - - >>> with PdbTestInput([ - ... 'next', - ... 'break 7', - ... 'continue', - ... 'next', - ... 'continue', - ... 'continue', - ... ]): - ... 
test_function() - > (4)test_function() - -> inst.botframe = sys._getframe() # hackery to get the right botframe - (Pdb) next - > (5)test_function() - -> print(1) - (Pdb) break 7 - Breakpoint 1 at :7 - (Pdb) continue - 1 - 2 - > (7)test_function() - -> print(3) - (Pdb) next - 3 - > (8)test_function() - -> print(4) - (Pdb) continue - 4 - """ - -class ModuleInitTester(unittest.TestCase): - - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") - def test_filename_correct(self): - """ - In issue 7750, it was found that if the filename has a sequence that - resolves to an escape character in a Python string (such as \t), it - will be treated as the escaped character. - """ - # the test_fn must contain something like \t - # on Windows, this will create 'test_mod.py' in the current directory. - # on Unix, this will create '.\test_mod.py' in the current directory. - test_fn = '.\\test_mod.py' - code = 'print("testing pdb")' - with open(test_fn, 'w') as f: - f.write(code) - self.addCleanup(os.remove, test_fn) - cmd = [sys.executable, '-m', 'pdb', test_fn,] - proc = subprocess.Popen(cmd, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - stdout, stderr = proc.communicate('quit\n') - self.assertIn(code, stdout, "pdb munged the filename") - - -def test_main(): - from test import test_pdb - test_support.run_doctest(test_pdb, verbosity=True) - test_support.run_unittest(ModuleInitTester) - -if __name__ == '__main__': - test_main() diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -969,6 +969,22 @@ def captured_stdin(): return captured_output("stdin") +def gc_collect(): + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) 
This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + if is_jython: + time.sleep(0.1) + gc.collect() + gc.collect() + _header = '2P' if hasattr(sys, "gettotalrefcount"): diff --git a/src/org/python/core/Py.java b/src/org/python/core/Py.java --- a/src/org/python/core/Py.java +++ b/src/org/python/core/Py.java @@ -1316,6 +1316,19 @@ if (globals == null || globals == Py.None) { globals = ts.frame.f_globals; + } else if (globals.__finditem__("__builtins__") == null) { + // Apply side effect of copying into globals, + // per documentation of eval and observed behavior of exec + try { + globals.__setitem__("__builtins__", Py.getSystemState().modules.__finditem__("__builtin__").__getattr__("__dict__")); + } catch (PyException e) { + // Quietly ignore if cannot set __builtins__ - Jython previously allowed a much wider range of + // mappable objects for the globals mapping than CPython, do not want to break existing code + // as we try to get better CPython compliance + if (!e.match(AttributeError)) { + throw e; + } + } } PyBaseCode baseCode = null; diff --git a/src/org/python/core/PyModule.java b/src/org/python/core/PyModule.java --- a/src/org/python/core/PyModule.java +++ b/src/org/python/core/PyModule.java @@ -60,7 +60,9 @@ ensureDict(); __dict__.__setitem__("__name__", name); __dict__.__setitem__("__doc__", doc); - __dict__.__setitem__("__package__", Py.None); + if (name.equals(new PyString("__main__"))) { + __dict__.__setitem__("__builtins__", Py.getSystemState().modules.__finditem__("__builtin__")); + } } public PyObject fastGetDict() { @@ -166,10 +168,24 @@ } public PyObject __dir__() { - if (__dict__ == null) { - throw Py.TypeError("module.__dict__ is not a dictionary"); + // Some special casing to ensure that classes deriving from PyModule + // can use their own __dict__. 
Although it would be nice to do this in + // PyModuleDerived, current templating in gderived.py does not support + // including from object, then overriding a specific method. + PyObject d; + if (this instanceof PyModuleDerived) { + d = __findattr_ex__("__dict__"); + } else { + d = __dict__; } - return __dict__.invoke("keys"); + if (d == null || + !(d instanceof PyDictionary || + d instanceof PyStringMap || + d instanceof PyDictProxy)) { + throw Py.TypeError(String.format("%.200s.__dict__ is not a dictionary", + getType().fastGetName().toLowerCase())); + } + return d.invoke("keys"); } private void ensureDict() { diff --git a/src/org/python/core/PyType.java b/src/org/python/core/PyType.java --- a/src/org/python/core/PyType.java +++ b/src/org/python/core/PyType.java @@ -203,6 +203,8 @@ type.bases = tmpBases.length == 0 ? new PyObject[] {PyObject.TYPE} : tmpBases; type.dict = dict; type.tp_flags = Py.TPFLAGS_HEAPTYPE | Py.TPFLAGS_BASETYPE; + // Enable defining a custom __dict__ via a property, method, or other descriptor + boolean defines_dict = dict.__finditem__("__dict__") != null; // immediately setup the javaProxy if applicable. 
may modify bases List> interfaces = Generic.list(); @@ -215,7 +217,7 @@ base.name)); } - type.createAllSlots(!base.needs_userdict, !base.needs_weakref); + type.createAllSlots(!(base.needs_userdict || defines_dict), !base.needs_weakref); type.ensureAttributes(); type.invalidateMethodCache(); diff --git a/src/org/python/core/__builtin__.java b/src/org/python/core/__builtin__.java --- a/src/org/python/core/__builtin__.java +++ b/src/org/python/core/__builtin__.java @@ -14,8 +14,6 @@ import org.python.antlr.base.mod; import org.python.core.stringlib.IntegerFormatter; -import org.python.core.stringlib.InternalFormat; -import org.python.core.stringlib.InternalFormat.Spec; import org.python.core.util.ExtraMath; import org.python.core.util.RelativeFile; import org.python.modules._functools._functools; @@ -70,8 +68,7 @@ case 5: return __builtin__.hash(arg1); case 6: - return Py.newUnicode(__builtin__.unichr(Py.py2int(arg1, "unichr(): 1st arg can't " - + "be coerced to int"))); + return Py.newUnicode(__builtin__.unichr(arg1)); case 7: return __builtin__.abs(arg1); case 9: @@ -401,6 +398,16 @@ return obj.isCallable(); } + public static int unichr(PyObject obj) { + long l = obj.asLong(); + if (l < PySystemState.minint) { + throw Py.OverflowError("signed integer is less than minimum"); + } else if (l > PySystemState.maxint) { + throw Py.OverflowError("signed integer is greater than maximum"); + } + return unichr((int)l); + } + public static int unichr(int i) { if (i < 0 || i > PySystemState.maxunicode) { throw Py.ValueError("unichr() arg not in range(0x110000)"); @@ -435,8 +442,11 @@ } public static PyObject dir(PyObject o) { - PyList ret = (PyList) o.__dir__(); - ret.sort(); + PyObject ret = o.__dir__(); + if (!Py.isInstance(ret, PyList.TYPE)) { + throw Py.TypeError("__dir__() must return a list, not " + ret.getType().fastGetName()); + } + ((PyList)ret).sort(); return ret; } @@ -884,39 +894,6 @@ y.getType().fastGetName(), z.getType().fastGetName())); } - public static PyObject 
range(PyObject start, PyObject stop, PyObject step) { - int ilow = 0; - int ihigh = 0; - int istep = 1; - int n; - - try { - ilow = start.asInt(); - ihigh = stop.asInt(); - istep = step.asInt(); - } catch (PyException pye) { - return handleRangeLongs(start, stop, step); - } - - if (istep == 0) { - throw Py.ValueError("range() step argument must not be zero"); - } - if (istep > 0) { - n = PyXRange.getLenOfRange(ilow, ihigh, istep); - } else { - n = PyXRange.getLenOfRange(ihigh, ilow, -istep); - } - if (n < 0) { - throw Py.OverflowError("range() result has too many items"); - } - - PyObject[] range = new PyObject[n]; - for (int i = 0; i < n; i++, ilow += istep) { - range[i] = Py.newInteger(ilow); - } - return new PyList(range); - } - public static PyObject range(PyObject n) { return range(Py.Zero, n, Py.One); } @@ -925,10 +902,7 @@ return range(start, stop, Py.One); } - /** - * Handle range() when PyLong arguments (that OverFlow ints) are given. - */ - private static PyObject handleRangeLongs(PyObject ilow, PyObject ihigh, PyObject istep) { + public static PyObject range(PyObject ilow, PyObject ihigh, PyObject istep) { ilow = getRangeLongArgument(ilow, "start"); ihigh = getRangeLongArgument(ihigh, "end"); istep = getRangeLongArgument(istep, "step"); @@ -949,8 +923,8 @@ PyObject[] range = new PyObject[n]; for (int i = 0; i < n; i++) { - range[i] = ilow.__long__(); - ilow = ilow.__add__(istep); + range[i] = ilow; + ilow = ilow._add(istep); } return new PyList(range); } @@ -1400,7 +1374,11 @@ @Override public PyObject __call__(PyObject arg1, PyObject arg2) { - return arg1.__format__(arg2); + PyObject formatted = arg1.__format__(arg2); + if (!Py.isInstance(formatted, PyString.TYPE) && !Py.isInstance(formatted, PyUnicode.TYPE) ) { + throw Py.TypeError("instance.__format__ must return string or unicode, not " + formatted.getType().fastGetName()); + } + return formatted; } } diff --git a/src/org/python/core/imp.java b/src/org/python/core/imp.java --- 
a/src/org/python/core/imp.java +++ b/src/org/python/core/imp.java @@ -127,6 +127,10 @@ return module; } module = new PyModule(name, null); + PyModule __builtin__ = (PyModule)modules.__finditem__("__builtin__"); + PyObject __dict__ = module.__getattr__("__dict__"); + __dict__.__setitem__("__builtins__", __builtin__.__getattr__("__dict__")); + __dict__.__setitem__("__package__", Py.None); modules.__setitem__(name, module); return module; } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 13 03:19:18 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 13 Dec 2014 02:19:18 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Remove_skip_given_Jython_su?= =?utf-8?q?pports_SHA224_since_r6649?= Message-ID: <20141213021916.81005.38115@psf.io> https://hg.python.org/jython/rev/2ba1632717b4 changeset: 7446:2ba1632717b4 user: Jim Baker date: Fri Dec 12 19:18:59 2014 -0700 summary: Remove skip given Jython supports SHA224 since r6649 No need for a patched test_hmac given SHA224 support since https://hg.python.org/jython/rev/47c55317a1c9 files: Lib/test/test_hmac.py | 318 ------------------------------ 1 files changed, 0 insertions(+), 318 deletions(-) diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py deleted file mode 100644 --- a/Lib/test/test_hmac.py +++ /dev/null @@ -1,318 +0,0 @@ -import hmac -import hashlib -import unittest -import warnings -from test import test_support - -class TestVectorsTestCase(unittest.TestCase): - - def test_md5_vectors(self): - # Test the HMAC module against test vectors from the RFC. 
- - def md5test(key, data, digest): - h = hmac.HMAC(key, data) - self.assertEqual(h.hexdigest().upper(), digest.upper()) - - md5test(chr(0x0b) * 16, - "Hi There", - "9294727A3638BB1C13F48EF8158BFC9D") - - md5test("Jefe", - "what do ya want for nothing?", - "750c783e6ab0b503eaa86e310a5db738") - - md5test(chr(0xAA)*16, - chr(0xDD)*50, - "56be34521d144c88dbb8c733f0e8b3f6") - - md5test("".join([chr(i) for i in range(1, 26)]), - chr(0xCD) * 50, - "697eaf0aca3a3aea3a75164746ffaa79") - - md5test(chr(0x0C) * 16, - "Test With Truncation", - "56461ef2342edc00f9bab995690efd4c") - - md5test(chr(0xAA) * 80, - "Test Using Larger Than Block-Size Key - Hash Key First", - "6b1ab7fe4bd7bf8f0b62e6ce61b9d0cd") - - md5test(chr(0xAA) * 80, - ("Test Using Larger Than Block-Size Key " - "and Larger Than One Block-Size Data"), - "6f630fad67cda0ee1fb1f562db3aa53e") - - def test_sha_vectors(self): - def shatest(key, data, digest): - h = hmac.HMAC(key, data, digestmod=hashlib.sha1) - self.assertEqual(h.hexdigest().upper(), digest.upper()) - - shatest(chr(0x0b) * 20, - "Hi There", - "b617318655057264e28bc0b6fb378c8ef146be00") - - shatest("Jefe", - "what do ya want for nothing?", - "effcdf6ae5eb2fa2d27416d5f184df9c259a7c79") - - shatest(chr(0xAA)*20, - chr(0xDD)*50, - "125d7342b9ac11cd91a39af48aa17b4f63f175d3") - - shatest("".join([chr(i) for i in range(1, 26)]), - chr(0xCD) * 50, - "4c9007f4026250c6bc8414f9bf50c86c2d7235da") - - shatest(chr(0x0C) * 20, - "Test With Truncation", - "4c1a03424b55e07fe7f27be1d58bb9324a9a5a04") - - shatest(chr(0xAA) * 80, - "Test Using Larger Than Block-Size Key - Hash Key First", - "aa4ae5e15272d00e95705637ce8a3b55ed402112") - - shatest(chr(0xAA) * 80, - ("Test Using Larger Than Block-Size Key " - "and Larger Than One Block-Size Data"), - "e8e99d0f45237d786d6bbaa7965c7808bbff1a91") - - def _rfc4231_test_cases(self, hashfunc): - def hmactest(key, data, hexdigests): - h = hmac.HMAC(key, data, digestmod=hashfunc) - self.assertEqual(h.hexdigest().lower(), 
hexdigests[hashfunc]) - - # 4.2. Test Case 1 - hmactest(key = '\x0b'*20, - data = 'Hi There', - hexdigests = { - hashlib.sha224: '896fb1128abbdf196832107cd49df33f' - '47b4b1169912ba4f53684b22', - hashlib.sha256: 'b0344c61d8db38535ca8afceaf0bf12b' - '881dc200c9833da726e9376c2e32cff7', - hashlib.sha384: 'afd03944d84895626b0825f4ab46907f' - '15f9dadbe4101ec682aa034c7cebc59c' - 'faea9ea9076ede7f4af152e8b2fa9cb6', - hashlib.sha512: '87aa7cdea5ef619d4ff0b4241a1d6cb0' - '2379f4e2ce4ec2787ad0b30545e17cde' - 'daa833b7d6b8a702038b274eaea3f4e4' - 'be9d914eeb61f1702e696c203a126854', - }) - - # 4.3. Test Case 2 - hmactest(key = 'Jefe', - data = 'what do ya want for nothing?', - hexdigests = { - hashlib.sha224: 'a30e01098bc6dbbf45690f3a7e9e6d0f' - '8bbea2a39e6148008fd05e44', - hashlib.sha256: '5bdcc146bf60754e6a042426089575c7' - '5a003f089d2739839dec58b964ec3843', - hashlib.sha384: 'af45d2e376484031617f78d2b58a6b1b' - '9c7ef464f5a01b47e42ec3736322445e' - '8e2240ca5e69e2c78b3239ecfab21649', - hashlib.sha512: '164b7a7bfcf819e2e395fbe73b56e0a3' - '87bd64222e831fd610270cd7ea250554' - '9758bf75c05a994a6d034f65f8f0e6fd' - 'caeab1a34d4a6b4b636e070a38bce737', - }) - - # 4.4. Test Case 3 - hmactest(key = '\xaa'*20, - data = '\xdd'*50, - hexdigests = { - hashlib.sha224: '7fb3cb3588c6c1f6ffa9694d7d6ad264' - '9365b0c1f65d69d1ec8333ea', - hashlib.sha256: '773ea91e36800e46854db8ebd09181a7' - '2959098b3ef8c122d9635514ced565fe', - hashlib.sha384: '88062608d3e6ad8a0aa2ace014c8a86f' - '0aa635d947ac9febe83ef4e55966144b' - '2a5ab39dc13814b94e3ab6e101a34f27', - hashlib.sha512: 'fa73b0089d56a284efb0f0756c890be9' - 'b1b5dbdd8ee81a3655f83e33b2279d39' - 'bf3e848279a722c806b485a47e67c807' - 'b946a337bee8942674278859e13292fb', - }) - - # 4.5. 
Test Case 4 - hmactest(key = ''.join([chr(x) for x in xrange(0x01, 0x19+1)]), - data = '\xcd'*50, - hexdigests = { - hashlib.sha224: '6c11506874013cac6a2abc1bb382627c' - 'ec6a90d86efc012de7afec5a', - hashlib.sha256: '82558a389a443c0ea4cc819899f2083a' - '85f0faa3e578f8077a2e3ff46729665b', - hashlib.sha384: '3e8a69b7783c25851933ab6290af6ca7' - '7a9981480850009cc5577c6e1f573b4e' - '6801dd23c4a7d679ccf8a386c674cffb', - hashlib.sha512: 'b0ba465637458c6990e5a8c5f61d4af7' - 'e576d97ff94b872de76f8050361ee3db' - 'a91ca5c11aa25eb4d679275cc5788063' - 'a5f19741120c4f2de2adebeb10a298dd', - }) - - # 4.7. Test Case 6 - hmactest(key = '\xaa'*131, - data = 'Test Using Larger Than Block-Siz' - 'e Key - Hash Key First', - hexdigests = { - hashlib.sha224: '95e9a0db962095adaebe9b2d6f0dbce2' - 'd499f112f2d2b7273fa6870e', - hashlib.sha256: '60e431591ee0b67f0d8a26aacbf5b77f' - '8e0bc6213728c5140546040f0ee37f54', - hashlib.sha384: '4ece084485813e9088d2c63a041bc5b4' - '4f9ef1012a2b588f3cd11f05033ac4c6' - '0c2ef6ab4030fe8296248df163f44952', - hashlib.sha512: '80b24263c7c1a3ebb71493c1dd7be8b4' - '9b46d1f41b4aeec1121b013783f8f352' - '6b56d037e05f2598bd0fd2215d6a1e52' - '95e64f73f63f0aec8b915a985d786598', - }) - - # 4.8. Test Case 7 - hmactest(key = '\xaa'*131, - data = 'This is a test using a larger th' - 'an block-size key and a larger t' - 'han block-size data. 
The key nee' - 'ds to be hashed before being use' - 'd by the HMAC algorithm.', - hexdigests = { - hashlib.sha224: '3a854166ac5d9f023f54d517d0b39dbd' - '946770db9c2b95c9f6f565d1', - hashlib.sha256: '9b09ffa71b942fcb27635fbcd5b0e944' - 'bfdc63644f0713938a7f51535c3a35e2', - hashlib.sha384: '6617178e941f020d351e2f254e8fd32c' - '602420feb0b8fb9adccebb82461e99c5' - 'a678cc31e799176d3860e6110c46523e', - hashlib.sha512: 'e37b6a775dc87dbaa4dfa9f96e5e3ffd' - 'debd71f8867289865df5a32d20cdc944' - 'b6022cac3c4982b10d5eeb55c3e4de15' - '134676fb6de0446065c97440fa8c6a58', - }) - - def test_sha224_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha224) - - def test_sha256_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha256) - - def test_sha384_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha384) - - def test_sha512_rfc4231(self): - self._rfc4231_test_cases(hashlib.sha512) - - def test_legacy_block_size_warnings(self): - class MockCrazyHash(object): - """Ain't no block_size attribute here.""" - def __init__(self, *args): - self._x = hashlib.sha1(*args) - self.digest_size = self._x.digest_size - def update(self, v): - self._x.update(v) - def digest(self): - return self._x.digest() - - with warnings.catch_warnings(): - warnings.simplefilter('error', RuntimeWarning) - with self.assertRaises(RuntimeWarning): - hmac.HMAC('a', 'b', digestmod=MockCrazyHash) - self.fail('Expected warning about missing block_size') - - MockCrazyHash.block_size = 1 - with self.assertRaises(RuntimeWarning): - hmac.HMAC('a', 'b', digestmod=MockCrazyHash) - self.fail('Expected warning about small block_size') - - - -class ConstructorTestCase(unittest.TestCase): - - def test_normal(self): - # Standard constructor call. - failed = 0 - try: - h = hmac.HMAC("key") - except: - self.fail("Standard constructor call raised exception.") - - def test_withtext(self): - # Constructor call with text. 
- try: - h = hmac.HMAC("key", "hash this!") - except: - self.fail("Constructor call with text argument raised exception.") - - def test_withmodule(self): - # Constructor call with text and digest module. - try: - h = hmac.HMAC("key", "", hashlib.sha1) - except: - self.fail("Constructor call with hashlib.sha1 raised exception.") - -class SanityTestCase(unittest.TestCase): - - def test_default_is_md5(self): - # Testing if HMAC defaults to MD5 algorithm. - # NOTE: this whitebox test depends on the hmac class internals - h = hmac.HMAC("key") - self.assertTrue(h.digest_cons == hashlib.md5) - - def test_exercise_all_methods(self): - # Exercising all methods once. - # This must not raise any exceptions - try: - h = hmac.HMAC("my secret key") - h.update("compute the hash of this text!") - dig = h.digest() - dig = h.hexdigest() - h2 = h.copy() - except: - self.fail("Exception raised during normal usage of HMAC class.") - -class CopyTestCase(unittest.TestCase): - - def test_attributes(self): - # Testing if attributes are of same type. - h1 = hmac.HMAC("key") - h2 = h1.copy() - self.assertTrue(h1.digest_cons == h2.digest_cons, - "digest constructors don't match.") - self.assertTrue(type(h1.inner) == type(h2.inner), - "Types of inner don't match.") - self.assertTrue(type(h1.outer) == type(h2.outer), - "Types of outer don't match.") - - def test_realcopy(self): - # Testing if the copy method created a real copy. - h1 = hmac.HMAC("key") - h2 = h1.copy() - # Using id() in case somebody has overridden __cmp__. - self.assertTrue(id(h1) != id(h2), "No real copy of the HMAC instance.") - self.assertTrue(id(h1.inner) != id(h2.inner), - "No real copy of the attribute 'inner'.") - self.assertTrue(id(h1.outer) != id(h2.outer), - "No real copy of the attribute 'outer'.") - - def test_equality(self): - # Testing if the copy has the same digests. 
- h1 = hmac.HMAC("key") - h1.update("some random text") - h2 = h1.copy() - self.assertTrue(h1.digest() == h2.digest(), - "Digest of copy doesn't match original digest.") - self.assertTrue(h1.hexdigest() == h2.hexdigest(), - "Hexdigest of copy doesn't match original hexdigest.") - -def test_main(): - if test_support.is_jython: - # XXX: Jython doesn't support sha224 - del TestVectorsTestCase.test_sha224_rfc4231 - hashlib.sha224 = None - test_support.run_unittest( - TestVectorsTestCase, - ConstructorTestCase, - SanityTestCase, - CopyTestCase - ) - -if __name__ == "__main__": - test_main() -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 13 03:35:25 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 13 Dec 2014 02:35:25 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Re-enable_memory_leak_test_?= =?utf-8?q?on_type_=3C-=3E_class_mapping?= Message-ID: <20141213023523.92257.3712@psf.io> https://hg.python.org/jython/rev/fe4e370f23c2 changeset: 7447:fe4e370f23c2 user: Jim Baker date: Fri Dec 12 19:35:19 2014 -0700 summary: Re-enable memory leak test on type <-> class mapping Taking the len (or size()) of weak maps in Google Guava is eventually consistent, which is not terribly useful for our testing. Instead we measure this by taking the length of its keys (or other iterable). files: Lib/test/test_jy_internals.py | 26 +++++++--------------- 1 files changed, 9 insertions(+), 17 deletions(-) diff --git a/Lib/test/test_jy_internals.py b/Lib/test/test_jy_internals.py --- a/Lib/test/test_jy_internals.py +++ b/Lib/test/test_jy_internals.py @@ -1,7 +1,6 @@ """ test some jython internals """ -import gc import unittest import time from test import test_support @@ -18,7 +17,6 @@ class MemoryLeakTests(unittest.TestCase): - @unittest.skip("FIXME: broken in 2.7.") def test_class_to_test_weakness(self): # regrtest for bug 1522, adapted from test code submitted by Matt Brinkley @@ -28,16 +26,6 @@ # `type`!) 
class_to_type_map = getField(type, 'class_to_type').get(None) - def make_clean(): - # gc a few times just to be really sure, since in this - # case we don't really care if it takes a few cycles of GC - # for the garbage to be reached - gc.collect() - time.sleep(0.1) - gc.collect() - time.sleep(0.5) - gc.collect() - def create_proxies(): pi = PythonInterpreter() pi.exec(""" @@ -51,16 +39,20 @@ Dog().bark() """) - make_clean() - # get to steady state first, then verify we don't create new proxies for i in xrange(2): create_proxies() - start_size = class_to_type_map.size() + # Ensure the reaper thread can run and clear out weak refs, so + # use this supporting function + test_support.gc_collect() + # Given that taking the len (or size()) of Guava weak maps is + # eventually consistent, we should instead take a len of its + # keys. + start_size = len(list(class_to_type_map)) for i in xrange(5): create_proxies() - make_clean() - self.assertEqual(start_size, class_to_type_map.size()) + test_support.gc_collect() + self.assertEqual(start_size, len(list(class_to_type_map))) class WeakIdentityMapTests(unittest.TestCase): -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 13 04:34:49 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 13 Dec 2014 03:34:49 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Update_test=5Ffuncattrs_and?= =?utf-8?q?_remove_skip?= Message-ID: <20141213033441.28228.88185@psf.io> https://hg.python.org/jython/rev/afe00ee26119 changeset: 7448:afe00ee26119 user: Jim Baker date: Fri Dec 12 20:34:37 2014 -0700 summary: Update test_funcattrs and remove skip Jython is more uniform in its attribute model than CPython. Unfortunately we have more tests depending on such attempted settings of read-only attributes resulting in a TypeError than an AttributeError. But fixing this seems rather pointless, so deferring to Jython 3.x and changing test_funcattrs accordingly.
See http://bugs.python.org/issue1687163 files: Lib/test/test_funcattrs.py | 15 ++++++++------- 1 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_funcattrs.py b/Lib/test/test_funcattrs.py --- a/Lib/test/test_funcattrs.py +++ b/Lib/test/test_funcattrs.py @@ -62,9 +62,7 @@ def test_func_globals(self): self.assertIs(self.b.func_globals, globals()) - self.assertIs(self.b.__globals__, globals()) self.cannot_set_attr(self.b, 'func_globals', 2, TypeError) - self.cannot_set_attr(self.b, '__globals__', 2, TypeError) def test_func_closure(self): a = 12 @@ -150,10 +148,8 @@ return a+b self.assertEqual(first_func.func_defaults, None) self.assertEqual(second_func.func_defaults, (1, 2)) - self.assertEqual(second_func.func_defaults, second_func.__defaults__) first_func.func_defaults = (1, 2) self.assertEqual(first_func.func_defaults, (1, 2)) - self.assertEqual(first_func.func_defaults, first_func.__defaults__) self.assertEqual(first_func(), 3) self.assertEqual(first_func(3), 5) self.assertEqual(first_func(3, 5), 8) @@ -312,7 +308,6 @@ class FunctionDocstringTest(FuncAttrsTest): - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") def test_set_docstring_attr(self): self.assertEqual(self.b.__doc__, None) self.assertEqual(self.b.func_doc, None) @@ -322,8 +317,14 @@ self.assertEqual(self.b.func_doc, docstr) self.assertEqual(self.f.a.__doc__, docstr) self.assertEqual(self.fi.a.__doc__, docstr) - self.cannot_set_attr(self.f.a, "__doc__", docstr, AttributeError) - self.cannot_set_attr(self.fi.a, "__doc__", docstr, AttributeError) + # Jython is more uniform in its attribute model than CPython. + # Unfortunately we have more tests depending on such attempted + # settings of read-only attributes resulting in a TypeError + # than an AttributeError. But fixing this seems pointless for + # now, deferring to Jython 3.x. 
See + # http://bugs.python.org/issue1687163 + self.cannot_set_attr(self.f.a, "__doc__", docstr, TypeError) + self.cannot_set_attr(self.fi.a, "__doc__", docstr, TypeError) def test_delete_docstring(self): self.b.__doc__ = "The docstring" -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 13 05:20:09 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 13 Dec 2014 04:20:09 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fix_time=2Easctime_so_it_ch?= =?utf-8?q?ecks_its_time_sequence_param?= Message-ID: <20141213042009.28222.66216@psf.io> https://hg.python.org/jython/rev/f55d51fa7843 changeset: 7449:f55d51fa7843 user: Jim Baker date: Fri Dec 12 21:20:04 2014 -0700 summary: Fix time.asctime so it checks its time sequence param files: Lib/test/test_time.py | 1 - src/org/python/modules/time/Time.java | 15 ++++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py --- a/Lib/test/test_time.py +++ b/Lib/test/test_time.py @@ -125,7 +125,6 @@ except ValueError: self.fail('strptime failed on empty args.') - @unittest.skip("FIXME: broken") def test_asctime(self): time.asctime(time.gmtime(self.t)) self.assertRaises(TypeError, time.asctime, 0) diff --git a/src/org/python/modules/time/Time.java b/src/org/python/modules/time/Time.java --- a/src/org/python/modules/time/Time.java +++ b/src/org/python/modules/time/Time.java @@ -31,6 +31,7 @@ import org.python.core.PyException; import org.python.core.PyInteger; import org.python.core.PyObject; +import org.python.core.PySequence; import org.python.core.PyString; import org.python.core.PyTuple; import org.python.core.__builtin__; @@ -412,7 +413,19 @@ return asctime(localtime()); } - public static PyString asctime(PyTuple tup) { + public static PyString asctime(PyObject obj) { + PyTuple tup; + if (obj instanceof PyTuple) { + tup = (PyTuple)obj; + } else { + tup = PyTuple.fromIterable(obj); + } + int len = tup.__len__(); 
+ if (len != 9) { + throw Py.TypeError( + String.format("argument must be sequence of length 9, not %d", len)); + } + StringBuilder buf = new StringBuilder(25); buf.append(enshortdays[item(tup, 6)]).append(' '); buf.append(enshortmonths[item(tup, 1)]).append(' '); -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 13 05:26:12 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 13 Dec 2014 04:26:12 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_For_now=2C_skip_testing_of_?= =?utf-8?q?test=5Fcmd=5Fline=5Fscript?= Message-ID: <20141213042611.18145.64620@psf.io> https://hg.python.org/jython/rev/aa6cafe2a1cc changeset: 7450:aa6cafe2a1cc user: Jim Baker date: Fri Dec 12 21:26:05 2014 -0700 summary: For now, skip testing of test_cmd_line_script Command line testing is currently too hard. Triaging for a possible post 2.7.0 fix. files: Lib/test/regrtest.py | 3 +++ 1 files changed, 3 insertions(+), 0 deletions(-) diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py --- a/Lib/test/regrtest.py +++ b/Lib/test/regrtest.py @@ -1266,6 +1266,9 @@ test_asynchat test_asyncore + # Command line testing is hard for Jython to do, but revisit + test_cmd_line_script + # Tests that should work with socket-reboot, but currently hang test_ftplib test_httplib -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 14 08:59:49 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 14 Dec 2014 07:59:49 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fix_problems_in_bz2=2C_gzip?= =?utf-8?q?=2C_and_tarfile?= Message-ID: <20141214075948.50117.3491@psf.io> https://hg.python.org/jython/rev/241308c580fc changeset: 7451:241308c580fc user: Jim Baker date: Sun Dec 14 00:59:44 2014 -0700 summary: Fix problems in bz2, gzip, and tarfile Update tarfile so that it has latest bug fixes from 2.7. Fix bz2 such that PyBZ2Decompressor does not inadvertently use charset decoding when constructing from bytes.
Fix gzip so __enter__ properly checks if closed (derivation was checking the wrong property implementation). Remove Jython-specific test_tarfile files: Lib/gzip.py | 106 +- Lib/tarfile.py | 43 +- Lib/test/test_tarfile.py | 1568 ---------- src/org/python/core/PyByteArray.java | 4 +- src/org/python/modules/bz2/PyBZ2Decompressor.java | 11 +- 5 files changed, 85 insertions(+), 1647 deletions(-) diff --git a/Lib/gzip.py b/Lib/gzip.py --- a/Lib/gzip.py +++ b/Lib/gzip.py @@ -21,9 +21,6 @@ # or unsigned. output.write(struct.pack(" self.extrasize: - self._read(readsize) - readsize = min(self.max_read_chunk, readsize * 2) - except EOFError: - if size > self.extrasize: - size = self.extrasize + while size > self.extrasize: + if not self._read(readsize): + if size > self.extrasize: + size = self.extrasize + break + readsize = min(self.max_read_chunk, readsize * 2) offset = self.offset - self.extrastart chunk = self.extrabuf[offset: offset + size] @@ -272,7 +274,7 @@ def _read(self, size=1024): if self.fileobj is None: - raise EOFError, "Reached EOF" + return False if self._new_member: # If the _new_member flag is set, we have to @@ -283,7 +285,7 @@ pos = self.fileobj.tell() # Save current position self.fileobj.seek(0, 2) # Seek to end of file if pos == self.fileobj.tell(): - raise EOFError, "Reached EOF" + return False else: self.fileobj.seek( pos ) # Return to original position @@ -300,9 +302,10 @@ if buf == "": uncompress = self.decompress.flush() + self.fileobj.seek(-len(self.decompress.unused_data), 1) self._read_eof() self._add_read_data( uncompress ) - raise EOFError, 'Reached EOF' + return False uncompress = self.decompress.decompress(buf) self._add_read_data( uncompress ) @@ -312,13 +315,14 @@ # so seek back to the start of the unused data, finish up # this member, and read a new gzip header. 
# (The number of bytes to seek back is the length of the unused - # data, minus 8 because _read_eof() will rewind a further 8 bytes) - self.fileobj.seek( -len(self.decompress.unused_data)+8, 1) + # data) + self.fileobj.seek(-len(self.decompress.unused_data), 1) # Check the CRC and file size, and set the flag so we read # a new member on the next call self._read_eof() self._new_member = True + return True def _add_read_data(self, data): self.crc = zlib.crc32(data, self.crc) & 0xffffffffL @@ -329,14 +333,11 @@ self.size = self.size + len(data) def _read_eof(self): - # We've read to the end of the file, so we have to rewind in order - # to reread the 8 bytes containing the CRC and the file size. + # We've read to the end of the file. # We check the that the computed CRC and size of the # uncompressed data matches the stored values. Note that the size # stored is the true file size mod 2**32. - self.fileobj.seek(-8, 1) - crc32 = read32(self.fileobj) - isize = read32(self.fileobj) # may exceed 2GB + crc32, isize = struct.unpack(" 0, - "tarfile is empty") - - # The test_*_size tests test for bug #1167128. 
- def test_file_size(self): - tar = tarfile.open(tmpname, self.mode) - - path = os.path.join(TEMPDIR, "file") - fobj = open(path, "wb") - fobj.close() - tarinfo = tar.gettarinfo(path) - self.assertEqual(tarinfo.size, 0) - - fobj = open(path, "wb") - fobj.write("aaa") - fobj.close() - tarinfo = tar.gettarinfo(path) - self.assertEqual(tarinfo.size, 3) - - tar.close() - - def test_directory_size(self): - path = os.path.join(TEMPDIR, "directory") - os.mkdir(path) - try: - tar = tarfile.open(tmpname, self.mode) - tarinfo = tar.gettarinfo(path) - self.assertEqual(tarinfo.size, 0) - finally: - os.rmdir(path) - - def test_link_size(self): - if hasattr(os, "link"): - link = os.path.join(TEMPDIR, "link") - target = os.path.join(TEMPDIR, "link_target") - fobj = open(target, "wb") - fobj.write("aaa") - fobj.close() - os.link(target, link) - try: - tar = tarfile.open(tmpname, self.mode) - # Record the link target in the inodes list. - tar.gettarinfo(target) - tarinfo = tar.gettarinfo(link) - self.assertEqual(tarinfo.size, 0) - finally: - os.remove(target) - os.remove(link) - - def test_symlink_size(self): - if hasattr(os, "symlink"): - path = os.path.join(TEMPDIR, "symlink") - os.symlink("link_target", path) - try: - tar = tarfile.open(tmpname, self.mode) - tarinfo = tar.gettarinfo(path) - self.assertEqual(tarinfo.size, 0) - finally: - os.remove(path) - - def test_add_self(self): - # Test for #1257255. 
- dstname = os.path.abspath(tmpname) - - tar = tarfile.open(tmpname, self.mode) - self.assertTrue(tar.name == dstname, "archive name must be absolute") - - tar.add(dstname) - self.assertTrue(tar.getnames() == [], "added the archive to itself") - - cwd = os.getcwd() - os.chdir(TEMPDIR) - tar.add(dstname) - os.chdir(cwd) - self.assertTrue(tar.getnames() == [], "added the archive to itself") - - def test_exclude(self): - tempdir = os.path.join(TEMPDIR, "exclude") - os.mkdir(tempdir) - try: - for name in ("foo", "bar", "baz"): - name = os.path.join(tempdir, name) - open(name, "wb").close() - - exclude = os.path.isfile - - tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1") - with test_support.check_warnings(("use the filter argument", - DeprecationWarning)): - tar.add(tempdir, arcname="empty_dir", exclude=exclude) - tar.close() - - tar = tarfile.open(tmpname, "r") - self.assertEqual(len(tar.getmembers()), 1) - self.assertEqual(tar.getnames()[0], "empty_dir") - finally: - shutil.rmtree(tempdir) - - def test_filter(self): - tempdir = os.path.join(TEMPDIR, "filter") - os.mkdir(tempdir) - try: - for name in ("foo", "bar", "baz"): - name = os.path.join(tempdir, name) - open(name, "wb").close() - - def filter(tarinfo): - if os.path.basename(tarinfo.name) == "bar": - return - tarinfo.uid = 123 - tarinfo.uname = "foo" - return tarinfo - - tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1") - tar.add(tempdir, arcname="empty_dir", filter=filter) - tar.close() - - tar = tarfile.open(tmpname, "r") - for tarinfo in tar: - self.assertEqual(tarinfo.uid, 123) - self.assertEqual(tarinfo.uname, "foo") - self.assertEqual(len(tar.getmembers()), 3) - tar.close() - finally: - shutil.rmtree(tempdir) - - # Guarantee that stored pathnames are not modified. Don't - # remove ./ or ../ or double slashes. Still make absolute - # pathnames relative. - # For details see bug #6054. 
- def _test_pathname(self, path, cmp_path=None, dir=False): - # Create a tarfile with an empty member named path - # and compare the stored name with the original. - foo = os.path.join(TEMPDIR, "foo") - if not dir: - open(foo, "w").close() - else: - os.mkdir(foo) - - tar = tarfile.open(tmpname, self.mode) - tar.add(foo, arcname=path) - tar.close() - - tar = tarfile.open(tmpname, "r") - t = tar.next() - tar.close() - - if not dir: - os.remove(foo) - else: - os.rmdir(foo) - - self.assertEqual(t.name, cmp_path or path.replace(os.sep, "/")) - - def test_pathnames(self): - self._test_pathname("foo") - self._test_pathname(os.path.join("foo", ".", "bar")) - self._test_pathname(os.path.join("foo", "..", "bar")) - self._test_pathname(os.path.join(".", "foo")) - self._test_pathname(os.path.join(".", "foo", ".")) - self._test_pathname(os.path.join(".", "foo", ".", "bar")) - self._test_pathname(os.path.join(".", "foo", "..", "bar")) - self._test_pathname(os.path.join(".", "foo", "..", "bar")) - self._test_pathname(os.path.join("..", "foo")) - self._test_pathname(os.path.join("..", "foo", "..")) - self._test_pathname(os.path.join("..", "foo", ".", "bar")) - self._test_pathname(os.path.join("..", "foo", "..", "bar")) - - self._test_pathname("foo" + os.sep + os.sep + "bar") - self._test_pathname("foo" + os.sep + os.sep, "foo", dir=True) - - def test_abs_pathnames(self): - if sys.platform == "win32": - self._test_pathname("C:\\foo", "foo") - else: - self._test_pathname("/foo", "foo") - self._test_pathname("///foo", "foo") - - def test_cwd(self): - # Test adding the current working directory. - cwd = os.getcwd() - os.chdir(TEMPDIR) - try: - open("foo", "w").close() - - tar = tarfile.open(tmpname, self.mode) - tar.add(".") - tar.close() - - tar = tarfile.open(tmpname, "r") - for t in tar: - self.assert_(t.name == "." 
or t.name.startswith("./")) - tar.close() - finally: - os.chdir(cwd) - - -class StreamWriteTest(WriteTestBase): - - mode = "w|" - - def test_stream_padding(self): - # Test for bug #1543303. - tar = tarfile.open(tmpname, self.mode) - tar.close() - - if self.mode.endswith("gz"): - fobj = gzip.GzipFile(tmpname) - data = fobj.read() - fobj.close() - elif self.mode.endswith("bz2"): - dec = bz2.BZ2Decompressor() - data = open(tmpname, "rb").read() - data = dec.decompress(data) - self.assertTrue(len(dec.unused_data) == 0, - "found trailing data") - else: - fobj = open(tmpname, "rb") - data = fobj.read() - fobj.close() - - self.assertTrue(data.count("\0") == tarfile.RECORDSIZE, - "incorrect zero padding") - - def test_file_mode(self): - # Test for issue #8464: Create files with correct - # permissions. - if sys.platform == "win32" or not hasattr(os, "umask"): - return - - if os.path.exists(tmpname): - os.remove(tmpname) - - original_umask = os.umask(0022) - try: - tar = tarfile.open(tmpname, self.mode) - tar.close() - mode = os.stat(tmpname).st_mode & 0777 - self.assertEqual(mode, 0644, "wrong file permissions") - finally: - os.umask(original_umask) - - -class GNUWriteTest(unittest.TestCase): - # This testcase checks for correct creation of GNU Longname - # and Longlink extended headers (cp. bug #812325). 
- - def _length(self, s): - blocks, remainder = divmod(len(s) + 1, 512) - if remainder: - blocks += 1 - return blocks * 512 - - def _calc_size(self, name, link=None): - # Initial tar header - count = 512 - - if len(name) > tarfile.LENGTH_NAME: - # GNU longname extended header + longname - count += 512 - count += self._length(name) - if link is not None and len(link) > tarfile.LENGTH_LINK: - # GNU longlink extended header + longlink - count += 512 - count += self._length(link) - return count - - def _test(self, name, link=None): - tarinfo = tarfile.TarInfo(name) - if link: - tarinfo.linkname = link - tarinfo.type = tarfile.LNKTYPE - - tar = tarfile.open(tmpname, "w") - tar.format = tarfile.GNU_FORMAT - tar.addfile(tarinfo) - - v1 = self._calc_size(name, link) - v2 = tar.offset - self.assertTrue(v1 == v2, "GNU longname/longlink creation failed") - - tar.close() - - tar = tarfile.open(tmpname) - member = tar.next() - self.assertIsNotNone(member, - "unable to read longname member") - self.assertEqual(tarinfo.name, member.name, - "unable to read longname member") - self.assertEqual(tarinfo.linkname, member.linkname, - "unable to read longname member") - tar.close() - - def test_longname_1023(self): - self._test(("longnam/" * 127) + "longnam") - - def test_longname_1024(self): - self._test(("longnam/" * 127) + "longname") - - def test_longname_1025(self): - self._test(("longnam/" * 127) + "longname_") - - def test_longlink_1023(self): - self._test("name", ("longlnk/" * 127) + "longlnk") - - def test_longlink_1024(self): - self._test("name", ("longlnk/" * 127) + "longlink") - - def test_longlink_1025(self): - self._test("name", ("longlnk/" * 127) + "longlink_") - - def test_longnamelink_1023(self): - self._test(("longnam/" * 127) + "longnam", - ("longlnk/" * 127) + "longlnk") - - def test_longnamelink_1024(self): - self._test(("longnam/" * 127) + "longname", - ("longlnk/" * 127) + "longlink") - - def test_longnamelink_1025(self): - self._test(("longnam/" * 127) + 
"longname_", - ("longlnk/" * 127) + "longlink_") - - -class HardlinkTest(unittest.TestCase): - # Test the creation of LNKTYPE (hardlink) members in an archive. - - def setUp(self): - self.foo = os.path.join(TEMPDIR, "foo") - self.bar = os.path.join(TEMPDIR, "bar") - - fobj = open(self.foo, "wb") - fobj.write("foo") - fobj.close() - - os.link(self.foo, self.bar) - - self.tar = tarfile.open(tmpname, "w") - self.tar.add(self.foo) - - def tearDown(self): - self.tar.close() - os.remove(self.foo) - os.remove(self.bar) - - def test_add_twice(self): - # The same name will be added as a REGTYPE every - # time regardless of st_nlink. - tarinfo = self.tar.gettarinfo(self.foo) - self.assertTrue(tarinfo.type == tarfile.REGTYPE, - "add file as regular failed") - - def test_add_hardlink(self): - tarinfo = self.tar.gettarinfo(self.bar) - self.assertTrue(tarinfo.type == tarfile.LNKTYPE, - "add file as hardlink failed") - - def test_dereference_hardlink(self): - self.tar.dereference = True - tarinfo = self.tar.gettarinfo(self.bar) - self.assertTrue(tarinfo.type == tarfile.REGTYPE, - "dereferencing hardlink failed") - - -class PaxWriteTest(GNUWriteTest): - - def _test(self, name, link=None): - # See GNUWriteTest. 
- tarinfo = tarfile.TarInfo(name) - if link: - tarinfo.linkname = link - tarinfo.type = tarfile.LNKTYPE - - tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT) - tar.addfile(tarinfo) - tar.close() - - tar = tarfile.open(tmpname) - if link: - l = tar.getmembers()[0].linkname - self.assertTrue(link == l, "PAX longlink creation failed") - else: - n = tar.getmembers()[0].name - self.assertTrue(name == n, "PAX longname creation failed") - - def test_pax_global_header(self): - pax_headers = { - u"foo": u"bar", - u"uid": u"0", - u"mtime": u"1.23", - u"test": u"???", - u"???": u"test"} - - tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, - pax_headers=pax_headers) - tar.addfile(tarfile.TarInfo("test")) - tar.close() - - # Test if the global header was written correctly. - tar = tarfile.open(tmpname, encoding="iso8859-1") - self.assertEqual(tar.pax_headers, pax_headers) - self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers) - - # Test if all the fields are unicode. - for key, val in tar.pax_headers.iteritems(): - self.assertTrue(type(key) is unicode) - self.assertTrue(type(val) is unicode) - if key in tarfile.PAX_NUMBER_FIELDS: - try: - tarfile.PAX_NUMBER_FIELDS[key](val) - except (TypeError, ValueError): - self.fail("unable to convert pax header field") - - def test_pax_extended_header(self): - # The fields from the pax header have priority over the - # TarInfo. - pax_headers = {u"path": u"foo", u"uid": u"123"} - - tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, encoding="iso8859-1") - t = tarfile.TarInfo() - t.name = u"???" 
# non-ASCII - t.uid = 8**8 # too large - t.pax_headers = pax_headers - tar.addfile(t) - tar.close() - - tar = tarfile.open(tmpname, encoding="iso8859-1") - t = tar.getmembers()[0] - self.assertEqual(t.pax_headers, pax_headers) - self.assertEqual(t.name, "foo") - self.assertEqual(t.uid, 123) - - -class UstarUnicodeTest(unittest.TestCase): - # All *UnicodeTests FIXME - - format = tarfile.USTAR_FORMAT - - def test_iso8859_1_filename(self): - self._test_unicode_filename("iso8859-1") - - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") - def test_utf7_filename(self): - self._test_unicode_filename("utf7") - - def test_utf8_filename(self): - self._test_unicode_filename("utf8") - - def _test_unicode_filename(self, encoding): - tar = tarfile.open(tmpname, "w", format=self.format, encoding=encoding, errors="strict") - name = u"???" - tar.addfile(tarfile.TarInfo(name)) - tar.close() - - tar = tarfile.open(tmpname, encoding=encoding) - self.assertTrue(type(tar.getnames()[0]) is not unicode) - self.assertEqual(tar.getmembers()[0].name, name.encode(encoding)) - tar.close() - - def test_unicode_filename_error(self): - tar = tarfile.open(tmpname, "w", format=self.format, encoding="ascii", errors="strict") - tarinfo = tarfile.TarInfo() - - tarinfo.name = "???" - if self.format == tarfile.PAX_FORMAT: - self.assertRaises(UnicodeError, tar.addfile, tarinfo) - else: - tar.addfile(tarinfo) - - tarinfo.name = u"???" - self.assertRaises(UnicodeError, tar.addfile, tarinfo) - - tarinfo.name = "foo" - tarinfo.uname = u"???" 
- self.assertRaises(UnicodeError, tar.addfile, tarinfo) - - def test_unicode_argument(self): - tar = tarfile.open(tarname, "r", encoding="iso8859-1", errors="strict") - for t in tar: - self.assertTrue(type(t.name) is str) - self.assertTrue(type(t.linkname) is str) - self.assertTrue(type(t.uname) is str) - self.assertTrue(type(t.gname) is str) - tar.close() - - def test_uname_unicode(self): - for name in (u"???", "???"): - t = tarfile.TarInfo("foo") - t.uname = name - t.gname = name - - fobj = StringIO.StringIO() - tar = tarfile.open("foo.tar", mode="w", fileobj=fobj, format=self.format, encoding="iso8859-1") - tar.addfile(t) - tar.close() - fobj.seek(0) - - tar = tarfile.open("foo.tar", fileobj=fobj, encoding="iso8859-1") - t = tar.getmember("foo") - self.assertEqual(t.uname, "???") - self.assertEqual(t.gname, "???") - - -class GNUUnicodeTest(UstarUnicodeTest): - - format = tarfile.GNU_FORMAT - - -class PaxUnicodeTest(UstarUnicodeTest): - - format = tarfile.PAX_FORMAT - - def _create_unicode_name(self, name): - tar = tarfile.open(tmpname, "w", format=self.format) - t = tarfile.TarInfo() - t.pax_headers["path"] = name - tar.addfile(t) - tar.close() - - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") - def test_error_handlers(self): - # Test if the unicode error handlers work correctly for characters - # that cannot be expressed in a given encoding. - self._create_unicode_name(u"???") - - for handler, name in (("utf-8", u"???".encode("utf8")), - ("replace", "???"), ("ignore", "")): - tar = tarfile.open(tmpname, format=self.format, encoding="ascii", - errors=handler) - self.assertEqual(tar.getnames()[0], name) - - self.assertRaises(UnicodeError, tarfile.open, tmpname, - encoding="ascii", errors="strict") - - def test_error_handler_utf8(self): - # Create a pathname that has one component representable using - # iso8859-1 and the other only in iso8859-15. 
- self._create_unicode_name(u"???/?") - - tar = tarfile.open(tmpname, format=self.format, encoding="iso8859-1", - errors="utf-8") - self.assertEqual(tar.getnames()[0], "???/" + u"?".encode("utf8")) - - -class AppendTest(unittest.TestCase): - # Test append mode (cp. patch #1652681). - - def setUp(self): - self.tarname = tmpname - if os.path.exists(self.tarname): - os.remove(self.tarname) - - def _add_testfile(self, fileobj=None): - tar = tarfile.open(self.tarname, "a", fileobj=fileobj) - tar.addfile(tarfile.TarInfo("bar")) - tar.close() - - def _create_testtar(self, mode="w:"): - src = tarfile.open(tarname, encoding="iso8859-1") - t = src.getmember("ustar/regtype") - t.name = "foo" - f = src.extractfile(t) - tar = tarfile.open(self.tarname, mode) - tar.addfile(t, f) - tar.close() - - def _test(self, names=["bar"], fileobj=None): - tar = tarfile.open(self.tarname, fileobj=fileobj) - self.assertEqual(tar.getnames(), names) - - def test_non_existing(self): - self._add_testfile() - self._test() - - def test_empty(self): - tarfile.open(self.tarname, "w:").close() - self._add_testfile() - self._test() - - def test_empty_fileobj(self): - fobj = StringIO.StringIO("\0" * 1024) - self._add_testfile(fobj) - fobj.seek(0) - self._test(fileobj=fobj) - - def test_fileobj(self): - self._create_testtar() - data = open(self.tarname).read() - fobj = StringIO.StringIO(data) - self._add_testfile(fobj) - fobj.seek(0) - self._test(names=["foo", "bar"], fileobj=fobj) - - def test_existing(self): - self._create_testtar() - self._add_testfile() - self._test(names=["foo", "bar"]) - - def test_append_gz(self): - if gzip is None: - return - self._create_testtar("w:gz") - self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a") - - def test_append_bz2(self): - if bz2 is None: - return - self._create_testtar("w:bz2") - self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a") - - # Append mode is supposed to fail if the tarfile to append to - # does not end with a zero block. 
- def _test_error(self, data): - open(self.tarname, "wb").write(data) - self.assertRaises(tarfile.ReadError, self._add_testfile) - - def test_null(self): - self._test_error("") - - def test_incomplete(self): - self._test_error("\0" * 13) - - def test_premature_eof(self): - data = tarfile.TarInfo("foo").tobuf() - self._test_error(data) - - def test_trailing_garbage(self): - data = tarfile.TarInfo("foo").tobuf() - self._test_error(data + "\0" * 13) - - def test_invalid(self): - self._test_error("a" * 512) - - -class LimitsTest(unittest.TestCase): - - def test_ustar_limits(self): - # 100 char name - tarinfo = tarfile.TarInfo("0123456789" * 10) - tarinfo.tobuf(tarfile.USTAR_FORMAT) - - # 101 char name that cannot be stored - tarinfo = tarfile.TarInfo("0123456789" * 10 + "0") - self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT) - - # 256 char name with a slash at pos 156 - tarinfo = tarfile.TarInfo("123/" * 62 + "longname") - tarinfo.tobuf(tarfile.USTAR_FORMAT) - - # 256 char name that cannot be stored - tarinfo = tarfile.TarInfo("1234567/" * 31 + "longname") - self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT) - - # 512 char name - tarinfo = tarfile.TarInfo("123/" * 126 + "longname") - self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT) - - # 512 char linkname - tarinfo = tarfile.TarInfo("longlink") - tarinfo.linkname = "123/" * 126 + "longname" - self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT) - - # uid > 8 digits - tarinfo = tarfile.TarInfo("name") - tarinfo.uid = 010000000 - self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT) - - def test_gnu_limits(self): - tarinfo = tarfile.TarInfo("123/" * 126 + "longname") - tarinfo.tobuf(tarfile.GNU_FORMAT) - - tarinfo = tarfile.TarInfo("longlink") - tarinfo.linkname = "123/" * 126 + "longname" - tarinfo.tobuf(tarfile.GNU_FORMAT) - - # uid >= 256 ** 7 - tarinfo = tarfile.TarInfo("name") - tarinfo.uid = 04000000000000000000L - 
self.assertRaises(ValueError, tarinfo.tobuf, tarfile.GNU_FORMAT) - - def test_pax_limits(self): - tarinfo = tarfile.TarInfo("123/" * 126 + "longname") - tarinfo.tobuf(tarfile.PAX_FORMAT) - - tarinfo = tarfile.TarInfo("longlink") - tarinfo.linkname = "123/" * 126 + "longname" - tarinfo.tobuf(tarfile.PAX_FORMAT) - - tarinfo = tarfile.TarInfo("name") - tarinfo.uid = 04000000000000000000L - tarinfo.tobuf(tarfile.PAX_FORMAT) - - -class ContextManagerTest(unittest.TestCase): - - def test_basic(self): - with tarfile.open(tarname) as tar: - self.assertFalse(tar.closed, "closed inside runtime context") - self.assertTrue(tar.closed, "context manager failed") - - def test_closed(self): - # The __enter__() method is supposed to raise IOError - # if the TarFile object is already closed. - tar = tarfile.open(tarname) - tar.close() - with self.assertRaises(IOError): - with tar: - pass - - def test_exception(self): - # Test if the IOError exception is passed through properly. - with self.assertRaises(Exception) as exc: - with tarfile.open(tarname) as tar: - raise IOError - self.assertIsInstance(exc.exception, IOError, - "wrong exception raised in context manager") - self.assertTrue(tar.closed, "context manager failed") - - def test_no_eof(self): - # __exit__() must not write end-of-archive blocks if an - # exception was raised. - try: - with tarfile.open(tmpname, "w") as tar: - raise Exception - except: - pass - self.assertEqual(os.path.getsize(tmpname), 0, - "context manager wrote an end-of-archive block") - self.assertTrue(tar.closed, "context manager failed") - - def test_eof(self): - # __exit__() must write end-of-archive blocks, i.e. call - # TarFile.close() if there was no error. - with tarfile.open(tmpname, "w"): - pass - self.assertNotEqual(os.path.getsize(tmpname), 0, - "context manager wrote no end-of-archive block") - - def test_fileobj(self): - # Test that __exit__() did not close the external file - # object. 
- fobj = open(tmpname, "wb") - try: - with tarfile.open(fileobj=fobj, mode="w") as tar: - raise Exception - except: - pass - self.assertFalse(fobj.closed, "external file object was closed") - self.assertTrue(tar.closed, "context manager failed") - fobj.close() - - -class LinkEmulationTest(ReadTest): - - # Test for issue #8741 regression. On platforms that do not support - # symbolic or hard links tarfile tries to extract these types of members as - # the regular files they point to. - def _test_link_extraction(self, name): - self.tar.extract(name, TEMPDIR) - data = open(os.path.join(TEMPDIR, name), "rb").read() - self.assertEqual(md5sum(data), md5_regtype) - - def test_hardlink_extraction1(self): - self._test_link_extraction("ustar/lnktype") - - def test_hardlink_extraction2(self): - self._test_link_extraction("./ustar/linktest2/lnktype") - - def test_symlink_extraction1(self): - self._test_link_extraction("ustar/symtype") - - def test_symlink_extraction2(self): - self._test_link_extraction("./ustar/linktest2/symtype") - - -class GzipMiscReadTest(MiscReadTest): - tarname = gzipname - mode = "r:gz" -class GzipUstarReadTest(UstarReadTest): - tarname = gzipname - mode = "r:gz" -class GzipStreamReadTest(StreamReadTest): - tarname = gzipname - mode = "r|gz" -class GzipWriteTest(WriteTest): - mode = "w:gz" -class GzipStreamWriteTest(StreamWriteTest): - mode = "w|gz" - - -class Bz2MiscReadTest(MiscReadTest): - tarname = bz2name - mode = "r:bz2" -class Bz2UstarReadTest(UstarReadTest): - tarname = bz2name - mode = "r:bz2" -class Bz2StreamReadTest(StreamReadTest): - tarname = bz2name - mode = "r|bz2" -class Bz2WriteTest(WriteTest): - mode = "w:bz2" -class Bz2StreamWriteTest(StreamWriteTest): - mode = "w|bz2" - -class Bz2PartialReadTest(unittest.TestCase): - # Issue5068: The _BZ2Proxy.read() method loops forever - # on an empty or partial bzipped file. 
- - def _test_partial_input(self, mode): - class MyStringIO(StringIO.StringIO): - hit_eof = False - def read(self, n): - if self.hit_eof: - raise AssertionError("infinite loop detected in tarfile.open()") - self.hit_eof = self.pos == self.len - return StringIO.StringIO.read(self, n) - def seek(self, *args): - self.hit_eof = False - return StringIO.StringIO.seek(self, *args) - - data = bz2.compress(tarfile.TarInfo("foo").tobuf()) - for x in range(len(data) + 1): - try: - tarfile.open(fileobj=MyStringIO(data[:x]), mode=mode) - except tarfile.ReadError: - pass # we have no interest in ReadErrors - - def test_partial_input(self): - self._test_partial_input("r") - - def test_partial_input_bz2(self): - self._test_partial_input("r:bz2") - - -def test_main(): - os.makedirs(TEMPDIR) - - tests = [ - UstarReadTest, - MiscReadTest, - StreamReadTest, - DetectReadTest, - MemberReadTest, - GNUReadTest, - PaxReadTest, - WriteTest, - StreamWriteTest, - GNUWriteTest, - PaxWriteTest, - UstarUnicodeTest, - GNUUnicodeTest, - PaxUnicodeTest, - AppendTest, - LimitsTest, - ContextManagerTest, - ] - - if hasattr(os, "link"): - tests.append(HardlinkTest) - else: - tests.append(LinkEmulationTest) - - fobj = open(tarname, "rb") - data = fobj.read() - fobj.close() - - if gzip: - # Create testtar.tar.gz and add gzip-specific tests. - tar = gzip.open(gzipname, "wb") - tar.write(data) - tar.close() - - tests += [ - GzipMiscReadTest, - GzipUstarReadTest, - GzipStreamReadTest, - GzipWriteTest, - GzipStreamWriteTest, - ] - - if bz2: - # Create testtar.tar.bz2 and add bz2-specific tests. 
- tar = bz2.BZ2File(bz2name, "wb") - tar.write(data) - tar.close() - - tests += [ - Bz2MiscReadTest, - Bz2UstarReadTest, - Bz2StreamReadTest, - Bz2WriteTest, - Bz2StreamWriteTest, - Bz2PartialReadTest, - ] - - try: - test_support.run_unittest(*tests) - finally: - if os.path.exists(TEMPDIR): - shutil.rmtree(TEMPDIR) - -if __name__ == "__main__": - test_main() diff --git a/src/org/python/core/PyByteArray.java b/src/org/python/core/PyByteArray.java --- a/src/org/python/core/PyByteArray.java +++ b/src/org/python/core/PyByteArray.java @@ -151,7 +151,7 @@ * * @param storage pre-initialised with desired value: the caller should not keep a reference */ - PyByteArray(byte[] storage) { + public PyByteArray(byte[] storage) { super(TYPE); setStorage(storage); } @@ -165,7 +165,7 @@ * @throws IllegalArgumentException if the range [0:size] is not within the array bounds of the * storage. */ - PyByteArray(byte[] storage, int size) { + public PyByteArray(byte[] storage, int size) { super(TYPE); setStorage(storage, size); } diff --git a/src/org/python/modules/bz2/PyBZ2Decompressor.java b/src/org/python/modules/bz2/PyBZ2Decompressor.java --- a/src/org/python/modules/bz2/PyBZ2Decompressor.java +++ b/src/org/python/modules/bz2/PyBZ2Decompressor.java @@ -8,6 +8,7 @@ import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; import org.python.core.ArgParser; import org.python.core.Py; +import org.python.core.PyByteArray; import org.python.core.PyObject; import org.python.core.PyString; import org.python.core.PyType; @@ -89,19 +90,17 @@ return Py.EmptyString; } - ByteArrayOutputStream databuf = new ByteArrayOutputStream(); + PyByteArray databuf = new PyByteArray(); int currentByte = -1; try { while ((currentByte = decompressStream.read()) != -1) { - databuf.write(currentByte); + databuf.append((byte)currentByte); } - returnData = new PyString(new String(databuf.toByteArray())); + returnData = databuf.__str__(); if (compressedData.available() > 0) { byte[] unusedbuf = 
new byte[compressedData.available()]; compressedData.read(unusedbuf); - - unused_data = (PyString) unused_data.__add__(new PyString( - new String(unusedbuf))); + unused_data = (PyString)unused_data.__add__((new PyByteArray(unusedbuf)).__str__()); } eofReached = true; } catch (IOException e) { -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 14 09:36:08 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 14 Dec 2014 08:36:08 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Support_a_wider_range_of_bo?= =?utf-8?q?oleans_for_=5Fcodecs=2Eutf=5F8=5Fdecode?= Message-ID: <20141214083608.92271.3894@psf.io> https://hg.python.org/jython/rev/5de518935f88 changeset: 7452:5de518935f88 user: Jim Baker date: Sun Dec 14 01:36:04 2014 -0700 summary: Support a wider range of booleans for _codecs.utf_8_decode Fixes test_univnewlines files: src/org/python/modules/_codecs.java | 4 ++++ 1 files changed, 4 insertions(+), 0 deletions(-) diff --git a/src/org/python/modules/_codecs.java b/src/org/python/modules/_codecs.java --- a/src/org/python/modules/_codecs.java +++ b/src/org/python/modules/_codecs.java @@ -238,6 +238,10 @@ return utf_8_decode(str, errors, false); } + public static PyTuple utf_8_decode(String str, String errors, PyObject final_) { + return utf_8_decode(str, errors, final_.__nonzero__()); + } + public static PyTuple utf_8_decode(String str, String errors, boolean final_) { int[] consumed = final_ ? 
null : new int[1]; return decode_tuple(codecs.PyUnicode_DecodeUTF8Stateful(str, errors, consumed), final_ -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 14 15:24:55 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 14 Dec 2014 14:24:55 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Do_not_run_C_implementation?= =?utf-8?q?_specific_test_in_test=5Fssl?= Message-ID: <20141214142454.18137.25131@psf.io> https://hg.python.org/jython/rev/ad0974583e48 changeset: 7453:ad0974583e48 user: Jim Baker date: Sun Dec 14 07:24:47 2014 -0700 summary: Do not run C implementation specific test in test_ssl files: Lib/_sslcerts.py | 24 - Lib/test/test_ssl.py | 1395 ++++++++++++++++++++++++++++++ 2 files changed, 1395 insertions(+), 24 deletions(-) diff --git a/Lib/_sslcerts.py b/Lib/_sslcerts.py --- a/Lib/_sslcerts.py +++ b/Lib/_sslcerts.py @@ -31,33 +31,9 @@ log = logging.getLogger("ssl") -# FIXME what happens if reloaded? Security.addProvider(BouncyCastleProvider()) -# build the necessary certificate with a CertificateFactory; this can take the pem format: -# http://docs.oracle.com/javase/7/docs/api/java/security/cert/CertificateFactory.html#generateCertificate(java.io.InputStream) - -# not certain if we can include a private key in the pem file; see -# http://stackoverflow.com/questions/7216969/getting-rsa-private-key-from-pem-base64-encoded-private-key-file - - -# helpful advice for being able to manage ca_certs outside of Java's keystore -# specifically the example ReloadableX509TrustManager -# http://jcalcote.wordpress.com/2010/06/22/managing-a-dynamic-java-trust-store/ - -# in the case of http://docs.python.org/2/library/ssl.html#ssl.CERT_REQUIRED - -# http://docs.python.org/2/library/ssl.html#ssl.CERT_NONE -# https://github.com/rackerlabs/romper/blob/master/romper/trust.py#L15 -# -# it looks like CERT_OPTIONAL simply validates certificates if -# provided, probably something in checkServerTrusted - maybe a None -# 
arg? need to verify as usual with a real system... :) - -# http://alesaudate.wordpress.com/2010/08/09/how-to-dynamically-select-a-certificate-alias-when-invoking-web-services/ -# is somewhat relevant for managing the keyfile, certfile - def _get_ca_certs_trust_manager(ca_certs): trust_store = KeyStore.getInstance(KeyStore.getDefaultType()) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py new file mode 100644 --- /dev/null +++ b/Lib/test/test_ssl.py @@ -0,0 +1,1395 @@ +# Test the support for SSL and sockets + +import sys +import unittest +from test import test_support +import asyncore +import socket +import select +import time +import gc +import os +import errno +import pprint +import urllib, urlparse +import traceback +import weakref +import functools +import platform + +from BaseHTTPServer import HTTPServer +from SimpleHTTPServer import SimpleHTTPRequestHandler + +ssl = test_support.import_module("ssl") + +HOST = test_support.HOST +CERTFILE = None +SVN_PYTHON_ORG_ROOT_CERT = None + +def handle_error(prefix): + exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) + if test_support.verbose: + sys.stdout.write(prefix + exc_format) + + +class BasicTests(unittest.TestCase): + + def test_sslwrap_simple(self): + # A crude test for the legacy API + try: + ssl.sslwrap_simple(socket.socket(socket.AF_INET)) + except IOError, e: + if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that + pass + else: + raise + try: + ssl.sslwrap_simple(socket.socket(socket.AF_INET)._sock) + except IOError, e: + if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that + pass + else: + raise + +# Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2 +def skip_if_broken_ubuntu_ssl(func): + if hasattr(ssl, 'PROTOCOL_SSLv2'): + # We need to access the lower-level wrapper in order to create an + # implicit SSL context without trying to connect or listen. 
+ try: + import _ssl + except ImportError: + # The returned function won't get executed, just ignore the error + pass + @functools.wraps(func) + def f(*args, **kwargs): + try: + s = socket.socket(socket.AF_INET) + _ssl.sslwrap(s._sock, 0, None, None, + ssl.CERT_NONE, ssl.PROTOCOL_SSLv2, None, None) + except ssl.SSLError as e: + if (ssl.OPENSSL_VERSION_INFO == (0, 9, 8, 15, 15) and + platform.linux_distribution() == ('debian', 'squeeze/sid', '') + and 'Invalid SSL protocol variant specified' in str(e)): + raise unittest.SkipTest("Patched Ubuntu OpenSSL breaks behaviour") + return func(*args, **kwargs) + return f + else: + return func + + +class BasicSocketTests(unittest.TestCase): + + def test_constants(self): + #ssl.PROTOCOL_SSLv2 + ssl.PROTOCOL_SSLv23 + ssl.PROTOCOL_SSLv3 + ssl.PROTOCOL_TLSv1 + ssl.CERT_NONE + ssl.CERT_OPTIONAL + ssl.CERT_REQUIRED + + def test_random(self): + v = ssl.RAND_status() + if test_support.verbose: + sys.stdout.write("\n RAND_status is %d (%s)\n" + % (v, (v and "sufficient randomness") or + "insufficient randomness")) + self.assertRaises(TypeError, ssl.RAND_egd, 1) + self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1) + ssl.RAND_add("this is a random string", 75.0) + + @unittest.skipIf(test_support.is_jython, "Jython uses BouncyCastle") + def test_parse_cert(self): + # note that this uses an 'unofficial' function in _ssl.c, + # provided solely for this test, to exercise the certificate + # parsing code + p = ssl._ssl._test_decode_cert(CERTFILE, False) + if test_support.verbose: + sys.stdout.write("\n" + pprint.pformat(p) + "\n") + self.assertEqual(p['subject'], + ((('countryName', 'XY'),), + (('localityName', 'Castle Anthrax'),), + (('organizationName', 'Python Software Foundation'),), + (('commonName', 'localhost'),)) + ) + self.assertEqual(p['subjectAltName'], (('DNS', 'localhost'),)) + # Issue #13034: the subjectAltName in some certificates + # (notably projects.developer.nokia.com:443) wasn't parsed + p = 
ssl._ssl._test_decode_cert(NOKIACERT) + if test_support.verbose: + sys.stdout.write("\n" + pprint.pformat(p) + "\n") + self.assertEqual(p['subjectAltName'], + (('DNS', 'projects.developer.nokia.com'), + ('DNS', 'projects.forum.nokia.com')) + ) + + def test_DER_to_PEM(self): + with open(SVN_PYTHON_ORG_ROOT_CERT, 'r') as f: + pem = f.read() + d1 = ssl.PEM_cert_to_DER_cert(pem) + p2 = ssl.DER_cert_to_PEM_cert(d1) + d2 = ssl.PEM_cert_to_DER_cert(p2) + self.assertEqual(d1, d2) + if not p2.startswith(ssl.PEM_HEADER + '\n'): + self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2) + if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'): + self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2) + + def test_openssl_version(self): + n = ssl.OPENSSL_VERSION_NUMBER + t = ssl.OPENSSL_VERSION_INFO + s = ssl.OPENSSL_VERSION + self.assertIsInstance(n, (int, long)) + self.assertIsInstance(t, tuple) + self.assertIsInstance(s, str) + # Some sanity checks follow + # >= 0.9 + self.assertGreaterEqual(n, 0x900000) + # < 2.0 + self.assertLess(n, 0x20000000) + major, minor, fix, patch, status = t + self.assertGreaterEqual(major, 0) + self.assertLess(major, 2) + self.assertGreaterEqual(minor, 0) + self.assertLess(minor, 256) + self.assertGreaterEqual(fix, 0) + self.assertLess(fix, 256) + self.assertGreaterEqual(patch, 0) + self.assertLessEqual(patch, 26) + self.assertGreaterEqual(status, 0) + self.assertLessEqual(status, 15) + # Version string as returned by OpenSSL, the format might change + self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), + (s, t)) + + def test_ciphers(self): + if not test_support.is_resource_enabled('network'): + return + remote = ("svn.python.org", 443) + with test_support.transient_internet(remote[0]): + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_NONE, ciphers="ALL") + s.connect(remote) + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") + 
s.connect(remote) + # Error checking occurs when connecting, because the SSL context + # isn't created before. + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx") + with self.assertRaisesRegexp(ssl.SSLError, "No cipher can be selected"): + s.connect(remote) + + @test_support.cpython_only + def test_refcycle(self): + # Issue #7943: an SSL object doesn't create reference cycles with + # itself. + s = socket.socket(socket.AF_INET) + ss = ssl.wrap_socket(s) + wr = weakref.ref(ss) + del ss + self.assertEqual(wr(), None) + + def test_wrapped_unconnected(self): + # The _delegate_methods in socket.py are correctly delegated to by an + # unconnected SSLSocket, so they will raise a socket.error rather than + # something unexpected like TypeError. + s = socket.socket(socket.AF_INET) + ss = ssl.wrap_socket(s) + self.assertRaises(socket.error, ss.recv, 1) + self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) + self.assertRaises(socket.error, ss.recvfrom, 1) + self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) + self.assertRaises(socket.error, ss.send, b'x') + self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) + + +class NetworkedTests(unittest.TestCase): + + def test_connect(self): + with test_support.transient_internet("svn.python.org"): + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_NONE) + s.connect(("svn.python.org", 443)) + c = s.getpeercert() + if c: + self.fail("Peer cert %s shouldn't be here!") + s.close() + + # this should fail because we have no verification certs + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED) + try: + s.connect(("svn.python.org", 443)) + except ssl.SSLError: + pass + finally: + s.close() + + # this should succeed because we specify the root cert + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=SVN_PYTHON_ORG_ROOT_CERT) + try: + 
s.connect(("svn.python.org", 443)) + finally: + s.close() + + def test_connect_ex(self): + # Issue #11326: check connect_ex() implementation + with test_support.transient_internet("svn.python.org"): + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=SVN_PYTHON_ORG_ROOT_CERT) + try: + self.assertEqual(0, s.connect_ex(("svn.python.org", 443))) + self.assertTrue(s.getpeercert()) + finally: + s.close() + + def test_non_blocking_connect_ex(self): + # Issue #11326: non-blocking connect_ex() should allow handshake + # to proceed after the socket gets ready. + with test_support.transient_internet("svn.python.org"): + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=SVN_PYTHON_ORG_ROOT_CERT, + do_handshake_on_connect=False) + try: + s.setblocking(False) + rc = s.connect_ex(('svn.python.org', 443)) + # EWOULDBLOCK under Windows, EINPROGRESS elsewhere + self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK)) + # Wait for connect to finish + select.select([], [s], [], 5.0) + # Non-blocking handshake + while True: + try: + s.do_handshake() + break + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], [], 5.0) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], [], 5.0) + else: + raise + # SSL established + self.assertTrue(s.getpeercert()) + finally: + s.close() + + def test_timeout_connect_ex(self): + # Issue #12065: on a timeout, connect_ex() should return the original + # errno (mimicking the behaviour of non-SSL sockets). 
+ with test_support.transient_internet("svn.python.org"): + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=SVN_PYTHON_ORG_ROOT_CERT, + do_handshake_on_connect=False) + try: + s.settimeout(0.0000001) + rc = s.connect_ex(('svn.python.org', 443)) + if rc == 0: + self.skipTest("svn.python.org responded too quickly") + self.assertIn(rc, (errno.EAGAIN, errno.EWOULDBLOCK)) + finally: + s.close() + + def test_connect_ex_error(self): + with test_support.transient_internet("svn.python.org"): + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=SVN_PYTHON_ORG_ROOT_CERT) + try: + self.assertEqual(errno.ECONNREFUSED, + s.connect_ex(("svn.python.org", 444))) + finally: + s.close() + + @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows") + def test_makefile_close(self): + # Issue #5238: creating a file-like object with makefile() shouldn't + # delay closing the underlying "real socket" (here tested with its + # file descriptor, hence skipping the test under Windows). 
+ with test_support.transient_internet("svn.python.org"): + ss = ssl.wrap_socket(socket.socket(socket.AF_INET)) + ss.connect(("svn.python.org", 443)) + fd = ss.fileno() + f = ss.makefile() + f.close() + # The fd is still open + os.read(fd, 0) + # Closing the SSL socket should close the fd too + ss.close() + gc.collect() + with self.assertRaises(OSError) as e: + os.read(fd, 0) + self.assertEqual(e.exception.errno, errno.EBADF) + + def test_non_blocking_handshake(self): + with test_support.transient_internet("svn.python.org"): + s = socket.socket(socket.AF_INET) + s.connect(("svn.python.org", 443)) + s.setblocking(False) + s = ssl.wrap_socket(s, + cert_reqs=ssl.CERT_NONE, + do_handshake_on_connect=False) + count = 0 + while True: + try: + count += 1 + s.do_handshake() + break + except ssl.SSLError, err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], []) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], []) + else: + raise + s.close() + if test_support.verbose: + sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) + + def test_get_server_certificate(self): + with test_support.transient_internet("svn.python.org"): + pem = ssl.get_server_certificate(("svn.python.org", 443)) + if not pem: + self.fail("No server certificate on svn.python.org:443!") + + try: + pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) + except ssl.SSLError: + #should fail + pass + else: + self.fail("Got server certificate %s for svn.python.org!" % pem) + + pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) + if not pem: + self.fail("No server certificate on svn.python.org:443!") + if test_support.verbose: + sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) + + def test_algorithms(self): + # Issue #8484: all algorithms should be available when verifying a + # certificate. 
+ # SHA256 was added in OpenSSL 0.9.8 + if ssl.OPENSSL_VERSION_INFO < (0, 9, 8, 0, 15): + self.skipTest("SHA256 not available on %r" % ssl.OPENSSL_VERSION) + self.skipTest("remote host needs SNI, only available on Python 3.2+") + # NOTE: https://sha2.hboeck.de is another possible test host + remote = ("sha256.tbs-internet.com", 443) + sha256_cert = os.path.join(os.path.dirname(__file__), "sha256.pem") + with test_support.transient_internet("sha256.tbs-internet.com"): + s = ssl.wrap_socket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=sha256_cert,) + try: + s.connect(remote) + if test_support.verbose: + sys.stdout.write("\nCipher with %r is %r\n" % + (remote, s.cipher())) + sys.stdout.write("Certificate is:\n%s\n" % + pprint.pformat(s.getpeercert())) + finally: + s.close() + + +try: + import threading +except ImportError: + _have_threads = False +else: + _have_threads = True + + class ThreadedEchoServer(threading.Thread): + + class ConnectionHandler(threading.Thread): + + """A mildly complicated class, because we want it to work both + with and without the SSL wrapper around the socket connection, so + that we can test the STARTTLS functionality.""" + + def __init__(self, server, connsock): + self.server = server + self.running = False + self.sock = connsock + self.sock.setblocking(1) + self.sslconn = None + threading.Thread.__init__(self) + self.daemon = True + + def show_conn_details(self): + if self.server.certreqs == ssl.CERT_REQUIRED: + cert = self.sslconn.getpeercert() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") + cert_binary = self.sslconn.getpeercert(True) + if test_support.verbose and self.server.chatty: + sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") + cipher = self.sslconn.cipher() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") + + def 
wrap_conn(self): + try: + self.sslconn = ssl.wrap_socket(self.sock, server_side=True, + certfile=self.server.certificate, + ssl_version=self.server.protocol, + ca_certs=self.server.cacerts, + cert_reqs=self.server.certreqs, + ciphers=self.server.ciphers) + except ssl.SSLError as e: + # XXX Various errors can have happened here, for example + # a mismatching protocol version, an invalid certificate, + # or a low-level bug. This should be made more discriminating. + self.server.conn_errors.append(e) + if self.server.chatty: + handle_error("\n server: bad connection attempt from " + + str(self.sock.getpeername()) + ":\n") + self.close() + self.running = False + self.server.stop() + return False + else: + return True + + def read(self): + if self.sslconn: + return self.sslconn.read() + else: + return self.sock.recv(1024) + + def write(self, bytes): + if self.sslconn: + return self.sslconn.write(bytes) + else: + return self.sock.send(bytes) + + def close(self): + if self.sslconn: + self.sslconn.close() + else: + self.sock._sock.close() + + def run(self): + self.running = True + if not self.server.starttls_server: + if isinstance(self.sock, ssl.SSLSocket): + self.sslconn = self.sock + elif not self.wrap_conn(): + return + self.show_conn_details() + while self.running: + try: + msg = self.read() + if not msg: + # eof, so quit this handler + self.running = False + self.close() + elif msg.strip() == 'over': + if test_support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: client closed connection\n") + self.close() + return + elif self.server.starttls_server and msg.strip() == 'STARTTLS': + if test_support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") + self.write("OK\n") + if not self.wrap_conn(): + return + elif self.server.starttls_server and self.sslconn and msg.strip() == 'ENDTLS': + if test_support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read 
ENDTLS from client, sending OK...\n") + self.write("OK\n") + self.sslconn.unwrap() + self.sslconn = None + if test_support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: connection is now unencrypted...\n") + else: + if (test_support.verbose and + self.server.connectionchatty): + ctype = (self.sslconn and "encrypted") or "unencrypted" + sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n" + % (repr(msg), ctype, repr(msg.lower()), ctype)) + self.write(msg.lower()) + except ssl.SSLError: + if self.server.chatty: + handle_error("Test server failure:\n") + self.close() + self.running = False + # normally, we'd just stop here, but for the test + # harness, we want to stop the server + self.server.stop() + + def __init__(self, certificate, ssl_version=None, + certreqs=None, cacerts=None, + chatty=True, connectionchatty=False, starttls_server=False, + wrap_accepting_socket=False, ciphers=None): + + if ssl_version is None: + ssl_version = ssl.PROTOCOL_TLSv1 + if certreqs is None: + certreqs = ssl.CERT_NONE + self.certificate = certificate + self.protocol = ssl_version + self.certreqs = certreqs + self.cacerts = cacerts + self.ciphers = ciphers + self.chatty = chatty + self.connectionchatty = connectionchatty + self.starttls_server = starttls_server + self.sock = socket.socket() + self.flag = None + if wrap_accepting_socket: + self.sock = ssl.wrap_socket(self.sock, server_side=True, + certfile=self.certificate, + cert_reqs = self.certreqs, + ca_certs = self.cacerts, + ssl_version = self.protocol, + ciphers = self.ciphers) + if test_support.verbose and self.chatty: + sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock)) + self.port = test_support.bind_port(self.sock) + self.active = False + self.conn_errors = [] + threading.Thread.__init__(self) + self.daemon = True + + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + return self + + def __exit__(self, *args): + self.stop() + self.join() + 
+ def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.sock.settimeout(0.05) + self.sock.listen(5) + self.active = True + if self.flag: + # signal an event + self.flag.set() + while self.active: + try: + newconn, connaddr = self.sock.accept() + if test_support.verbose and self.chatty: + sys.stdout.write(' server: new connection from ' + + str(connaddr) + '\n') + handler = self.ConnectionHandler(self, newconn) + handler.start() + handler.join() + except socket.timeout: + pass + except KeyboardInterrupt: + self.stop() + self.sock.close() + + def stop(self): + self.active = False + + class AsyncoreEchoServer(threading.Thread): + + class EchoServer(asyncore.dispatcher): + + class ConnectionHandler(asyncore.dispatcher_with_send): + + def __init__(self, conn, certfile): + asyncore.dispatcher_with_send.__init__(self, conn) + self.socket = ssl.wrap_socket(conn, server_side=True, + certfile=certfile, + do_handshake_on_connect=False) + self._ssl_accepting = True + + def readable(self): + if isinstance(self.socket, ssl.SSLSocket): + while self.socket.pending() > 0: + self.handle_read_event() + return True + + def _do_ssl_handshake(self): + try: + self.socket.do_handshake() + except ssl.SSLError, err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return + elif err.args[0] == ssl.SSL_ERROR_EOF: + return self.handle_close() + raise + except socket.error, err: + if err.args[0] == errno.ECONNABORTED: + return self.handle_close() + else: + self._ssl_accepting = False + + def handle_read(self): + if self._ssl_accepting: + self._do_ssl_handshake() + else: + data = self.recv(1024) + if data and data.strip() != 'over': + self.send(data.lower()) + + def handle_close(self): + self.close() + if test_support.verbose: + sys.stdout.write(" server: closed connection %s\n" % self.socket) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.certfile = certfile + 
asyncore.dispatcher.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = test_support.bind_port(self.socket) + self.listen(5) + + def handle_accept(self): + sock_obj, addr = self.accept() + if test_support.verbose: + sys.stdout.write(" server: new connection from %s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self.certfile) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.flag = None + self.active = False + self.server = self.EchoServer(certfile) + self.port = self.server.port + threading.Thread.__init__(self) + self.daemon = True + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + return self + + def __exit__(self, *args): + if test_support.verbose: + sys.stdout.write(" cleanup: stopping server.\n") + self.stop() + if test_support.verbose: + sys.stdout.write(" cleanup: joining server thread.\n") + self.join() + if test_support.verbose: + sys.stdout.write(" cleanup: successfully joined.\n") + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + asyncore.loop(0.05) + + def stop(self): + self.active = False + self.server.close() + + class SocketServerHTTPSServer(threading.Thread): + + class HTTPSServer(HTTPServer): + + def __init__(self, server_address, RequestHandlerClass, certfile): + HTTPServer.__init__(self, server_address, RequestHandlerClass) + # we assume the certfile contains both private key and certificate + self.certfile = certfile + self.allow_reuse_address = True + + def __str__(self): + return ('<%s %s:%s>' % + (self.__class__.__name__, + self.server_name, + self.server_port)) + + def get_request(self): + # override this to wrap socket with SSL + sock, addr = self.socket.accept() + sslconn = ssl.wrap_socket(sock, server_side=True, + certfile=self.certfile) + 
return sslconn, addr + + class RootedHTTPRequestHandler(SimpleHTTPRequestHandler): + # need to override translate_path to get a known root, + # instead of using os.curdir, since the test could be + # run from anywhere + + server_version = "TestHTTPS/1.0" + + root = None + + def translate_path(self, path): + """Translate a /-separated PATH to the local filename syntax. + + Components that mean special things to the local file system + (e.g. drive or directory names) are ignored. (XXX They should + probably be diagnosed.) + + """ + # abandon query parameters + path = urlparse.urlparse(path)[2] + path = os.path.normpath(urllib.unquote(path)) + words = path.split('/') + words = filter(None, words) + path = self.root + for word in words: + drive, word = os.path.splitdrive(word) + head, word = os.path.split(word) + if word in self.root: continue + path = os.path.join(path, word) + return path + + def log_message(self, format, *args): + + # we override this to suppress logging unless "verbose" + + if test_support.verbose: + sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" % + (self.server.server_address, + self.server.server_port, + self.request.cipher(), + self.log_date_time_string(), + format%args)) + + + def __init__(self, certfile): + self.flag = None + self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0] + self.server = self.HTTPSServer( + (HOST, 0), self.RootedHTTPRequestHandler, certfile) + self.port = self.server.server_port + threading.Thread.__init__(self) + self.daemon = True + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + if self.flag: + self.flag.set() + self.server.serve_forever(0.05) + + def stop(self): + self.server.shutdown() + + + def bad_cert_test(certfile): + """ + Launch a server with CERT_REQUIRED, and check that trying to + connect to it with the given client certificate fails. 
+ """ + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_REQUIRED, + cacerts=CERTFILE, chatty=False) + with server: + try: + s = ssl.wrap_socket(socket.socket(), + certfile=certfile, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + except ssl.SSLError, x: + if test_support.verbose: + sys.stdout.write("\nSSLError is %s\n" % x[1]) + except socket.error, x: + if test_support.verbose: + sys.stdout.write("\nsocket.error is %s\n" % x[1]) + else: + raise AssertionError("Use of invalid cert should have failed!") + + def server_params_test(certfile, protocol, certreqs, cacertsfile, + client_certfile, client_protocol=None, indata="FOO\n", + ciphers=None, chatty=True, connectionchatty=False, + wrap_accepting_socket=False): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ + server = ThreadedEchoServer(certfile, + certreqs=certreqs, + ssl_version=protocol, + cacerts=cacertsfile, + ciphers=ciphers, + chatty=chatty, + connectionchatty=connectionchatty, + wrap_accepting_socket=wrap_accepting_socket) + with server: + # try to connect + if client_protocol is None: + client_protocol = protocol + s = ssl.wrap_socket(socket.socket(), + certfile=client_certfile, + ca_certs=cacertsfile, + ciphers=ciphers, + cert_reqs=certreqs, + ssl_version=client_protocol) + s.connect((HOST, server.port)) + for arg in [indata, bytearray(indata), memoryview(indata)]: + if connectionchatty: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(arg))) + s.write(arg) + outdata = s.read() + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise AssertionError( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: 
closing connection.\n") + s.close() + + def try_protocol_combo(server_protocol, + client_protocol, + expect_success, + certsreqs=None): + if certsreqs is None: + certsreqs = ssl.CERT_NONE + certtype = { + ssl.CERT_NONE: "CERT_NONE", + ssl.CERT_OPTIONAL: "CERT_OPTIONAL", + ssl.CERT_REQUIRED: "CERT_REQUIRED", + }[certsreqs] + if test_support.verbose: + formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" + sys.stdout.write(formatstr % + (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol), + certtype)) + try: + # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client + # will send an SSLv3 hello (rather than SSLv2) starting from + # OpenSSL 1.0.0 (see issue #8322). + server_params_test(CERTFILE, server_protocol, certsreqs, + CERTFILE, CERTFILE, client_protocol, + ciphers="ALL", chatty=False) + # Protocol mismatch can result in either an SSLError, or a + # "Connection reset by peer" error. + except ssl.SSLError: + if expect_success: + raise + except socket.error as e: + if expect_success or e.errno != errno.ECONNRESET: + raise + else: + if not expect_success: + raise AssertionError( + "Client protocol %s succeeded with server protocol %s!" + % (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol))) + + + class ThreadedTests(unittest.TestCase): + + def test_rude_shutdown(self): + """A brutal shutdown of an SSL server should raise an IOError + in the client when attempting handshake. + """ + listener_ready = threading.Event() + listener_gone = threading.Event() + + s = socket.socket() + port = test_support.bind_port(s, HOST) + + # `listener` runs in a thread. It sits in an accept() until + # the main thread connects. Then it rudely closes the socket, + # and sets Event `listener_gone` to let the main thread know + # the socket is gone. 
+ def listener(): + s.listen(5) + listener_ready.set() + s.accept() + s.close() + listener_gone.set() + + def connector(): + listener_ready.wait() + c = socket.socket() + c.connect((HOST, port)) + listener_gone.wait() + try: + ssl_sock = ssl.wrap_socket(c) + except IOError: + pass + else: + self.fail('connecting to closed SSL socket should have failed') + + t = threading.Thread(target=listener) + t.start() + try: + connector() + finally: + t.join() + + @skip_if_broken_ubuntu_ssl + def test_echo(self): + """Basic test of an SSL client connecting to a server""" + if test_support.verbose: + sys.stdout.write("\n") + server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE, + CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1, + chatty=True, connectionchatty=True) + + def test_getpeercert(self): + if test_support.verbose: + sys.stdout.write("\n") + s2 = socket.socket() + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_SSLv23, + cacerts=CERTFILE, + chatty=False) + with server: + s = ssl.wrap_socket(socket.socket(), + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_REQUIRED, + ssl_version=ssl.PROTOCOL_SSLv23) + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() + if test_support.verbose: + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." 
% + pprint.pformat(cert)) + if ((('organizationName', 'Python Software Foundation'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in certificate subject; " + "should be 'Python Software Foundation'.") + s.close() + + def test_empty_cert(self): + """Connecting with an empty cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "nullcert.pem")) + def test_malformed_cert(self): + """Connecting with a badly formatted certificate (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "badcert.pem")) + def test_nonexisting_cert(self): + """Connecting with a non-existing cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "wrongcert.pem")) + def test_malformed_key(self): + """Connecting with a badly formatted key (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "badkey.pem")) + + @skip_if_broken_ubuntu_ssl + def test_protocol_sslv2(self): + """Connecting to an SSLv2 server with various client options""" + if test_support.verbose: + sys.stdout.write("\n") + if not hasattr(ssl, 'PROTOCOL_SSLv2'): + self.skipTest("PROTOCOL_SSLv2 needed") + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) + + @skip_if_broken_ubuntu_ssl + def test_protocol_sslv23(self): + """Connecting to an SSLv23 server with various client options""" + if test_support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) + 
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + + @skip_if_broken_ubuntu_ssl + def test_protocol_sslv3(self): + """Connecting to an SSLv3 server with various client options""" + if test_support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) + + @skip_if_broken_ubuntu_ssl + def test_protocol_tlsv1(self): + """Connecting to a TLSv1 server with various client options""" + if test_support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" + msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6") + + server = 
ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + starttls_server=True, + chatty=True, + connectionchatty=True) + wrapped = False + with server: + s = socket.socket() + s.setblocking(1) + s.connect((HOST, server.port)) + if test_support.verbose: + sys.stdout.write("\n") + for indata in msgs: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % repr(indata)) + if wrapped: + conn.write(indata) + outdata = conn.read() + else: + s.send(indata) + outdata = s.recv(1024) + if (indata == "STARTTLS" and + outdata.strip().lower().startswith("ok")): + # STARTTLS ok, switch to secure mode + if test_support.verbose: + sys.stdout.write( + " client: read %s from server, starting TLS...\n" + % repr(outdata)) + conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) + wrapped = True + elif (indata == "ENDTLS" and + outdata.strip().lower().startswith("ok")): + # ENDTLS ok, switch back to clear text + if test_support.verbose: + sys.stdout.write( + " client: read %s from server, ending TLS...\n" + % repr(outdata)) + s = conn.unwrap() + wrapped = False + else: + if test_support.verbose: + sys.stdout.write( + " client: read %s from server\n" % repr(outdata)) + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + if wrapped: + conn.write("over\n") + else: + s.send("over\n") + s.close() + + def test_socketserver(self): + """Using a SocketServer to create and manage SSL connections.""" + server = SocketServerHTTPSServer(CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + if test_support.verbose: + sys.stdout.write('\n') + with open(CERTFILE, 'rb') as f: + d1 = f.read() + d2 = '' + # now fetch the same data from the HTTPS server + url = 'https://127.0.0.1:%d/%s' % ( + server.port, os.path.split(CERTFILE)[1]) + with test_support.check_py3k_warnings(): + f = urllib.urlopen(url) + dlen = f.info().getheader("content-length") + if dlen and (int(dlen) > 0): 
+ d2 = f.read(int(dlen)) + if test_support.verbose: + sys.stdout.write( + " client: read %d bytes from remote server '%s'\n" + % (len(d2), server)) + f.close() + self.assertEqual(d1, d2) + finally: + server.stop() + server.join() + + def test_wrapped_accept(self): + """Check the accept() method on SSL sockets.""" + if test_support.verbose: + sys.stdout.write("\n") + server_params_test(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED, + CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23, + chatty=True, connectionchatty=True, + wrap_accepting_socket=True) + + def test_asyncore_server(self): + """Check the example asyncore integration.""" + indata = "TEST MESSAGE of mixed case\n" + + if test_support.verbose: + sys.stdout.write("\n") + server = AsyncoreEchoServer(CERTFILE) + with server: + s = ssl.wrap_socket(socket.socket()) + s.connect(('127.0.0.1', server.port)) + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + self.fail( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + + def test_recv_send(self): + """Test recv(), send() and friends.""" + if test_support.verbose: + sys.stdout.write("\n") + + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + # helper methods for standardising recv* method signatures + def _recv_into(): + b = 
bytearray("\0"*100) + count = s.recv_into(b) + return b[:count] + + def _recvfrom_into(): + b = bytearray("\0"*100) + count, addr = s.recvfrom_into(b) + return b[:count] + + # (name, method, whether to expect success, *args) + send_methods = [ + ('send', s.send, True, []), + ('sendto', s.sendto, False, ["some.address"]), + ('sendall', s.sendall, True, []), + ] + recv_methods = [ + ('recv', s.recv, True, []), + ('recvfrom', s.recvfrom, False, ["some.address"]), + ('recv_into', _recv_into, True, []), + ('recvfrom_into', _recvfrom_into, False, []), + ] + data_prefix = u"PREFIX_" + + for meth_name, send_meth, expect_success, args in send_methods: + indata = data_prefix + meth_name + try: + send_meth(indata.encode('ASCII', 'strict'), *args) + outdata = s.read() + outdata = outdata.decode('ASCII', 'strict') + if outdata != indata.lower(): + self.fail( + "While sending with <<%s>> bad data " + "<<%r>> (%d) received; " + "expected <<%r>> (%d)\n" % ( + meth_name, outdata[:20], len(outdata), + indata[:20], len(indata) + ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to send with method <<%s>>; " + "expected to succeed.\n" % (meth_name,) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<%s>> failed with unexpected " + "exception message: %s\n" % ( + meth_name, e + ) + ) + + for meth_name, recv_meth, expect_success, args in recv_methods: + indata = data_prefix + meth_name + try: + s.send(indata.encode('ASCII', 'strict')) + outdata = recv_meth(*args) + outdata = outdata.decode('ASCII', 'strict') + if outdata != indata.lower(): + self.fail( + "While receiving with <<%s>> bad data " + "<<%r>> (%d) received; " + "expected <<%r>> (%d)\n" % ( + meth_name, outdata[:20], len(outdata), + indata[:20], len(indata) + ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to receive with method <<%s>>; " + "expected to succeed.\n" % (meth_name,) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<%s>> 
failed with unexpected " + "exception message: %s\n" % ( + meth_name, e + ) + ) + # consume data + s.read() + + s.write("over\n".encode("ASCII", "strict")) + s.close() + + def test_handshake_timeout(self): + # Issue #5103: SSL handshake must respect the socket timeout + server = socket.socket(socket.AF_INET) + host = "127.0.0.1" + port = test_support.bind_port(server) + started = threading.Event() + finish = False + + def serve(): + server.listen(5) + started.set() + conns = [] + while not finish: + r, w, e = select.select([server], [], [], 0.1) + if server in r: + # Let the socket hang around rather than having + # it closed by garbage collection. + conns.append(server.accept()[0]) + + t = threading.Thread(target=serve) + t.start() + started.wait() + + try: + try: + c = socket.socket(socket.AF_INET) + c.settimeout(0.2) + c.connect((host, port)) + # Will attempt handshake and time out + self.assertRaisesRegexp(ssl.SSLError, "timed out", + ssl.wrap_socket, c) + finally: + c.close() + try: + c = socket.socket(socket.AF_INET) + c.settimeout(0.2) + c = ssl.wrap_socket(c) + # Will attempt handshake and time out + self.assertRaisesRegexp(ssl.SSLError, "timed out", + c.connect, (host, port)) + finally: + c.close() + finally: + finish = True + t.join() + server.close() + + def test_default_ciphers(self): + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_SSLv23, + chatty=False) as server: + sock = socket.socket() + try: + # Force a set of weak ciphers on our client socket + try: + s = ssl.wrap_socket(sock, + ssl_version=ssl.PROTOCOL_SSLv23, + ciphers="DES") + except ssl.SSLError: + self.skipTest("no DES cipher available") + with self.assertRaises((OSError, ssl.SSLError)): + s.connect((HOST, server.port)) + finally: + sock.close() + self.assertIn("no shared cipher", str(server.conn_errors[0])) + + +def test_main(verbose=False): + global CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, NOKIACERT + CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + 
"keycert.pem") + SVN_PYTHON_ORG_ROOT_CERT = os.path.join( + os.path.dirname(__file__) or os.curdir, + "https_svn_python_org_root.pem") + NOKIACERT = os.path.join(os.path.dirname(__file__) or os.curdir, + "nokia.pem") + + if (not os.path.exists(CERTFILE) or + not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT) or + not os.path.exists(NOKIACERT)): + raise test_support.TestFailed("Can't read certificate files!") + + tests = [BasicTests, BasicSocketTests] + + if test_support.is_resource_enabled('network'): + tests.append(NetworkedTests) + + if _have_threads: + thread_info = test_support.threading_setup() + if thread_info and test_support.is_resource_enabled('network'): + tests.append(ThreadedTests) + + try: + test_support.run_unittest(*tests) + finally: + if _have_threads: + test_support.threading_cleanup(*thread_info) + +if __name__ == "__main__": + test_main() -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 14 16:11:41 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 14 Dec 2014 15:11:41 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Do_not_test_FileIO=28=27/de?= =?utf-8?b?di90dHknLCAndycpLnNlZWthYmxlKCkgaW4gdGVzdF9maWxlaW8=?= Message-ID: <20141214151140.28210.49589@psf.io> https://hg.python.org/jython/rev/625cdd97ad6d changeset: 7454:625cdd97ad6d user: Jim Baker date: Sun Dec 14 08:11:34 2014 -0700 summary: Do not test FileIO('/dev/tty', 'w').seekable() in test_fileio On OSX, _io.FileIO("/dev/tty", "w").isatty() is False On Ubuntu, _io.FileIO("/dev/tty", "w").isatty() throws IOError: Illegal seek Much like we see on other platforms, we cannot reliably determine it is not seekable (or perhaps special in general). Let's track in the related bug (http://bugs.jython.org/issue1945), instead of test_fileio, which is more about the new IO support in _io.FileIO. 
files: Lib/test/test_fileio.py | 14 +++++++++++--- 1 files changed, 11 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py --- a/Lib/test/test_fileio.py +++ b/Lib/test/test_fileio.py @@ -11,7 +11,7 @@ from functools import wraps from test.test_support import (TESTFN, check_warnings, run_unittest, - make_bad_fd, is_jython) + make_bad_fd, is_jython, gc_collect) from test.test_support import py3k_bytes as bytes from test.script_helper import run_python @@ -34,7 +34,6 @@ self.f.close() os.remove(TESTFN) - @unittest.skipIf(is_jython, "FIXME: not working in Jython") def testWeakRefs(self): # verify weak references p = proxy(self.f) @@ -42,6 +41,7 @@ self.assertEqual(self.f.tell(), p.tell()) self.f.close() self.f = None + gc_collect() self.assertRaises(ReferenceError, getattr, p, 'tell') def testSeekTell(self): @@ -294,7 +294,15 @@ self.assertEqual(f.isatty(), False) f.close() - if sys.platform != "win32": + # Jython specific issues: + # On OSX, FileIO("/dev/tty", "w").isatty() is False + # On Ubuntu, FileIO("/dev/tty", "w").isatty() throws IOError: Illegal seek + # + # Much like we see on other platforms, we cannot reliably + # determine it is not seekable (or special). 
+ # + # Related bug: http://bugs.jython.org/issue1945 + if sys.platform != "win32" and not is_jython: try: f = self.f = _FileIO("/dev/tty", "a") except EnvironmentError: -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 14 16:28:25 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 14 Dec 2014 15:28:25 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Skip_tests_of_=5Fsocket=2E?= =?utf-8?q?=5Fget=5Fjsockaddr_due_to_OS_differences?= Message-ID: <20141214152825.81001.44008@psf.io> https://hg.python.org/jython/rev/75bc9d636bdc changeset: 7455:75bc9d636bdc user: Jim Baker date: Sun Dec 14 08:28:22 2014 -0700 summary: Skip tests of _socket._get_jsockaddr due to OS differences files: Lib/test/test_socket.py | 8 ++++++++ 1 files changed, 8 insertions(+), 0 deletions(-) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -2071,6 +2071,14 @@ result = socket.getnameinfo(address, flags) self.failUnlessEqual(result[0], expected) + +# TODO: consider re-enabling this set of tests, but for now +# this set reliably does *not* work on Ubuntu (but does on +# OSX). However the underlying internal method _get_jsockaddr +# is exercised by nearly every socket usage, along with the +# corresponding tests. 
+ +@unittest.skipIf(test_support.is_jython, "Skip internal tests for address lookup due to underlying OS issues") class TestJython_get_jsockaddr(unittest.TestCase): "These tests are specific to jython: they test a key internal routine" -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 14 16:51:02 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 14 Dec 2014 15:51:02 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fix_test_confused_by_loggin?= =?utf-8?q?g_output_due_to_socket_being_closed?= Message-ID: <20141214155101.50117.67668@psf.io> https://hg.python.org/jython/rev/4149ebcc9ffb changeset: 7456:4149ebcc9ffb user: Jim Baker date: Sun Dec 14 08:50:57 2014 -0700 summary: Fix test confused by logging output due to socket being closed test_partial_post causes a close error (as might be expected) in the socket server, apparently because the timing is different between CPython and Jython. So ignore so that the default SocketServer.handle_error logging does not cause issues in unexpected text output in the overall regrtest. files: Lib/test/test_xmlrpc.py | 14 +++++++++++++- 1 files changed, 13 insertions(+), 1 deletions(-) diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -304,6 +304,15 @@ s.setblocking(True) return s, port + def handle_error(self, request, client_address): + # test_partial_post causes a close error (as might be + # expected), apparently because the timing is different + # between CPython and Jython. So ignore so that the + # default SocketServer.handle_error logging does not cause + # issues in unexpected text output in the overall + # regrtest. + pass + if not requestHandler: requestHandler = SimpleXMLRPCServer.SimpleXMLRPCRequestHandler serv = MyXMLRPCServer(("localhost", 0), requestHandler, @@ -605,7 +614,10 @@ # Check that a partial POST doesn't make the server loop: issue #14001.
conn = httplib.HTTPConnection(ADDR, PORT) conn.request('POST', '/RPC2 HTTP/1.0\r\nContent-Length: 100\r\n\r\nbye') - conn.close() + try: + conn.close() + except Exception, e: + print "Got this exception", type(e), e class MultiPathServerTestCase(BaseServerTestCase): threadFunc = staticmethod(http_multi_server) -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 14 17:12:55 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 14 Dec 2014 16:12:55 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_len_of_WeakSet_objs_is_even?= =?utf-8?q?tually_consistent?= Message-ID: <20141214161254.18151.94026@psf.io> https://hg.python.org/jython/rev/5a2fd25b4bc1 changeset: 7457:5a2fd25b4bc1 user: Jim Baker date: Sun Dec 14 09:12:51 2014 -0700 summary: len of WeakSet objs is eventually consistent This really only matters in tests which assume determinism. Fix that assumption in TestWeakSet.test_len files: Lib/test/test_weakset.py | 6 +++++- 1 files changed, 5 insertions(+), 1 deletions(-) diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -84,7 +84,11 @@ self.assertEqual(len(self.fs), 1) del self.obj gc.collect() - self.assertEqual(len(self.fs), 0) + # len of weak collections is eventually consistent on + # Jython. 
In practice this does not matter because of the + # nature of weaksets - we cannot rely on what happens in the + # reaper thread and how it interacts with gc + self.assertIn(len(self.fs), (0, 1)) def test_contains(self): for c in self.letters: -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Mon Dec 15 21:03:53 2014 From: jython-checkins at python.org (santoso.wijaya) Date: Mon, 15 Dec 2014 20:03:53 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fix_KeyError_message_when_p?= =?utf-8?q?opping_a_non-existent_dict_key=2E?= Message-ID: <20141215200346.18147.68773@psf.io> https://hg.python.org/jython/rev/f68d65a141f7 changeset: 7458:f68d65a141f7 user: Santoso Wijaya date: Mon Dec 15 12:04:24 2014 -0800 summary: Fix KeyError message when popping a non-existent dict key. files: src/org/python/core/PyDictionary.java | 2 +- 1 files changed, 1 insertions(+), 1 deletions(-) diff --git a/src/org/python/core/PyDictionary.java b/src/org/python/core/PyDictionary.java --- a/src/org/python/core/PyDictionary.java +++ b/src/org/python/core/PyDictionary.java @@ -618,7 +618,7 @@ final PyObject dict_pop(PyObject key, PyObject defaultValue) { if (!getMap().containsKey(key)) { if (defaultValue == null) { - throw Py.KeyError("popitem(): dictionary is empty"); + throw Py.KeyError(key.asString()); } return defaultValue; } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 16 23:46:08 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 16 Dec 2014 22:46:08 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Refactors_org=2Epython=2Eco?= =?utf-8?q?re=2Eimp?= Message-ID: <20141216224548.92269.17043@psf.io> https://hg.python.org/jython/rev/0ee9d9d530ba changeset: 7459:0ee9d9d530ba user: Jim Baker date: Tue Dec 16 15:45:44 2014 -0700 summary: Refactors org.python.core.imp All possible calls to the top-level module script during import are now guarded by a module import lock, which is now per
PySystemState. Removes statics for the module import lock and the SyspathLoader from org.python.core.imp. Instead, make them instances of PySystemState. In the case of the SyspathLoader (a ClassLoader), this change removes another source of possible resource leaks. Fixes http://bugs.jython.org/issue2205 files: Lib/test/test_import.py | 648 +++++++++++++ src/org/python/core/PySystemState.java | 15 + src/org/python/core/imp.java | 50 +- src/org/python/modules/_imp.java | 6 +- 4 files changed, 697 insertions(+), 22 deletions(-) diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py new file mode 100644 --- /dev/null +++ b/Lib/test/test_import.py @@ -0,0 +1,648 @@ +import errno +import imp +import marshal +import os +import py_compile +import random +import stat +import struct +import sys +import unittest +import textwrap +import shutil + +from test.test_support import (unlink, TESTFN, unload, run_unittest, rmtree, + is_jython, check_warnings, EnvironmentVarGuard) +from test import symlink_support +from test import script_helper + +def _files(name): + return (name + os.extsep + "py", + name + os.extsep + "pyc", + name + os.extsep + "pyo", + name + os.extsep + "pyw", + name + "$py.class") + +def chmod_files(name): + for f in _files(name): + try: + os.chmod(f, 0600) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + +def remove_files(name): + for f in _files(name): + unlink(f) + + +class ImportTests(unittest.TestCase): + + def tearDown(self): + unload(TESTFN) + setUp = tearDown + + def test_case_sensitivity(self): + # Brief digression to test that import is case-sensitive: if we got + # this far, we know for sure that "random" exists. + try: + import RAnDoM + except ImportError: + pass + else: + self.fail("import of RAnDoM should have failed (case mismatch)") + + def test_double_const(self): + # Another brief digression to test the accuracy of manifest float + # constants. 
+ from test import double_const # don't blink -- that *was* the test + + def test_import(self): + def test_with_extension(ext): + # The extension is normally ".py", perhaps ".pyw". + source = TESTFN + ext + pyo = TESTFN + os.extsep + "pyo" + if is_jython: + pyc = TESTFN + "$py.class" + else: + pyc = TESTFN + os.extsep + "pyc" + + with open(source, "w") as f: + print >> f, ("# This tests Python's ability to import a", ext, + "file.") + a = random.randrange(1000) + b = random.randrange(1000) + print >> f, "a =", a + print >> f, "b =", b + + try: + mod = __import__(TESTFN) + except ImportError, err: + self.fail("import from %s failed: %s" % (ext, err)) + else: + self.assertEqual(mod.a, a, + "module loaded (%s) but contents invalid" % mod) + self.assertEqual(mod.b, b, + "module loaded (%s) but contents invalid" % mod) + finally: + unlink(source) + + try: + imp.reload(mod) + except ImportError, err: + self.fail("import from .pyc/.pyo failed: %s" % err) + finally: + unlink(pyc) + unlink(pyo) + unload(TESTFN) + + sys.path.insert(0, os.curdir) + try: + test_with_extension(os.extsep + "py") + if sys.platform.startswith("win"): + for ext in [".PY", ".Py", ".pY", ".pyw", ".PYW", ".pYw"]: + test_with_extension(ext) + finally: + del sys.path[0] + + @unittest.skipUnless(os.name == 'posix', "test meaningful only on posix systems") + def test_execute_bit_not_copied(self): + # Issue 6070: under posix .pyc files got their execute bit set if + # the .py file had the execute bit set, but they aren't executable. 
+ oldmask = os.umask(022) + sys.path.insert(0, os.curdir) + try: + fname = TESTFN + os.extsep + "py" + f = open(fname, 'w').close() + os.chmod(fname, (stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH | + stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)) + __import__(TESTFN) + fn = fname + 'c' + if not os.path.exists(fn): + fn = fname + 'o' + if not os.path.exists(fn): + self.fail("__import__ did not result in creation of " + "either a .pyc or .pyo file") + s = os.stat(fn) + self.assertEqual(stat.S_IMODE(s.st_mode), + stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + finally: + os.umask(oldmask) + remove_files(TESTFN) + unload(TESTFN) + del sys.path[0] + + def test_rewrite_pyc_with_read_only_source(self): + # Issue 6074: a long time ago on posix, and more recently on Windows, + # a read only source file resulted in a read only pyc file, which + # led to problems with updating it later + sys.path.insert(0, os.curdir) + fname = TESTFN + os.extsep + "py" + try: + # Write a Python file, make it read-only and import it + with open(fname, 'w') as f: + f.write("x = 'original'\n") + # Tweak the mtime of the source to ensure pyc gets updated later + s = os.stat(fname) + os.utime(fname, (s.st_atime, s.st_mtime-100000000)) + os.chmod(fname, 0400) + m1 = __import__(TESTFN) + self.assertEqual(m1.x, 'original') + # Change the file and then reimport it + os.chmod(fname, 0600) + with open(fname, 'w') as f: + f.write("x = 'rewritten'\n") + unload(TESTFN) + m2 = __import__(TESTFN) + self.assertEqual(m2.x, 'rewritten') + # Now delete the source file and check the pyc was rewritten + unlink(fname) + unload(TESTFN) + m3 = __import__(TESTFN) + self.assertEqual(m3.x, 'rewritten') + finally: + chmod_files(TESTFN) + remove_files(TESTFN) + unload(TESTFN) + del sys.path[0] + + def test_imp_module(self): + # Verify that the imp module can correctly load and find .py files + + # XXX (ncoghlan): It would be nice to use test_support.CleanImport + # here, but that breaks because the os module registers some + # 
handlers in copy_reg on import. Since CleanImport doesn't + # revert that registration, the module is left in a broken + # state after reversion. Reinitialising the module contents + # and just reverting os.environ to its previous state is an OK + # workaround + orig_path = os.path + orig_getenv = os.getenv + with EnvironmentVarGuard(): + x = imp.find_module("os") + new_os = imp.load_module("os", *x) + self.assertIs(os, new_os) + self.assertIs(orig_path, new_os.path) + self.assertIsNot(orig_getenv, new_os.getenv) + + def test_module_with_large_stack(self, module='longlist'): + # Regression test for http://bugs.python.org/issue561858. + filename = module + os.extsep + 'py' + + # Create a file with a list of 65000 elements. + with open(filename, 'w+') as f: + f.write('d = [\n') + for i in range(65000): + f.write('"",\n') + f.write(']') + + # Compile & remove .py file, we only need .pyc (or .pyo). + with open(filename, 'r') as f: + py_compile.compile(filename) + unlink(filename) + + # Need to be able to load from current dir. + sys.path.append('') + + # This used to crash. + exec 'import ' + module + + # Cleanup. + del sys.path[-1] + unlink(filename + 'c') + unlink(filename + 'o') + + def test_failing_import_sticks(self): + source = TESTFN + os.extsep + "py" + with open(source, "w") as f: + print >> f, "a = 1 // 0" + + # New in 2.4, we shouldn't be able to import that no matter how often + # we try. + sys.path.insert(0, os.curdir) + try: + for i in [1, 2, 3]: + self.assertRaises(ZeroDivisionError, __import__, TESTFN) + self.assertNotIn(TESTFN, sys.modules, + "damaged module in sys.modules on %i try" % i) + finally: + del sys.path[0] + remove_files(TESTFN) + + def test_failing_reload(self): + # A failing reload should leave the module object in sys.modules. 
+ source = TESTFN + os.extsep + "py" + with open(source, "w") as f: + print >> f, "a = 1" + print >> f, "b = 2" + + sys.path.insert(0, os.curdir) + try: + mod = __import__(TESTFN) + self.assertIn(TESTFN, sys.modules) + self.assertEqual(mod.a, 1, "module has wrong attribute values") + self.assertEqual(mod.b, 2, "module has wrong attribute values") + + # On WinXP, just replacing the .py file wasn't enough to + # convince reload() to reparse it. Maybe the timestamp didn't + # move enough. We force it to get reparsed by removing the + # compiled file too. + remove_files(TESTFN) + + # Now damage the module. + with open(source, "w") as f: + print >> f, "a = 10" + print >> f, "b = 20//0" + + self.assertRaises(ZeroDivisionError, imp.reload, mod) + + # But we still expect the module to be in sys.modules. + mod = sys.modules.get(TESTFN) + self.assertIsNot(mod, None, "expected module to be in sys.modules") + + # We should have replaced a w/ 10, but the old b value should + # stick. + self.assertEqual(mod.a, 10, "module has wrong attribute values") + self.assertEqual(mod.b, 2, "module has wrong attribute values") + + finally: + del sys.path[0] + remove_files(TESTFN) + unload(TESTFN) + + def test_infinite_reload(self): + # http://bugs.python.org/issue742342 reports that Python segfaults + # (infinite recursion in C) when faced with self-recursive reload()ing. + + sys.path.insert(0, os.path.dirname(__file__)) + try: + import infinite_reload + finally: + del sys.path[0] + + def test_import_name_binding(self): + # import x.y.z binds x in the current namespace. + import test as x + import test.test_support + self.assertIs(x, test, x.__name__) + self.assertTrue(hasattr(test.test_support, "__file__")) + + # import x.y.z as w binds z as w. + import test.test_support as y + self.assertIs(y, test.test_support, y.__name__) + + def test_import_initless_directory_warning(self): + # FIXME this is tricky - how does interact with importing Java code? 
+ # jars are easy, but what about a directory of classes? + with check_warnings(('', ImportWarning)): + # Just a random non-package directory we always expect to be + # somewhere in sys.path... + self.assertRaises(ImportError, __import__, "site-packages") + + def test_import_by_filename(self): + path = os.path.abspath(TESTFN) + with self.assertRaises(ImportError) as c: + __import__(path) + self.assertEqual("Import by filename is not supported.", + c.exception.args[0]) + + def test_import_in_del_does_not_crash(self): + # Issue 4236 + testfn = script_helper.make_script('', TESTFN, textwrap.dedent("""\ + import sys + class C: + def __del__(self): + import imp + sys.argv.insert(0, C()) + """)) + try: + script_helper.assert_python_ok(testfn) + finally: + unlink(testfn) + + def test_bug7732(self): + source = TESTFN + '.py' + os.mkdir(source) + try: + self.assertRaises((ImportError, IOError), + imp.find_module, TESTFN, ["."]) + finally: + os.rmdir(source) + + def test_timestamp_overflow(self): + # A modification timestamp larger than 2**32 should not be a problem + # when importing a module (issue #11235). + sys.path.insert(0, os.curdir) + try: + source = TESTFN + ".py" + compiled = source + ('c' if __debug__ else 'o') + with open(source, 'w') as f: + pass + try: + os.utime(source, (2 ** 33 - 5, 2 ** 33 - 5)) + except OverflowError: + self.skipTest("cannot set modification time to large integer") + except OSError as e: + if e.errno != getattr(errno, 'EOVERFLOW', None): + raise + self.skipTest("cannot set modification time to large integer ({})".format(e)) + __import__(TESTFN) + # The pyc file was created. + os.stat(compiled) + finally: + del sys.path[0] + remove_files(TESTFN) + + def test_pyc_mtime(self): + # Test for issue #13863: .pyc timestamp sometimes incorrect on Windows. + sys.path.insert(0, os.curdir) + try: + # Jan 1, 2012; Jul 1, 2012. + mtimes = 1325376000, 1341100800 + + # Different names to avoid running into import caching. 
+ tails = "spam", "eggs" + for mtime, tail in zip(mtimes, tails): + module = TESTFN + tail + source = module + ".py" + compiled = source + ('c' if __debug__ else 'o') + + # Create a new Python file with the given mtime. + with open(source, 'w') as f: + f.write("# Just testing\nx=1, 2, 3\n") + os.utime(source, (mtime, mtime)) + + # Generate the .pyc/o file; if it couldn't be created + # for some reason, skip the test. + m = __import__(module) + if not os.path.exists(compiled): + unlink(source) + self.skipTest("Couldn't create .pyc/.pyo file.") + + # Actual modification time of .py file. + mtime1 = int(os.stat(source).st_mtime) & 0xffffffff + + # mtime that was encoded in the .pyc file. + with open(compiled, 'rb') as f: + mtime2 = struct.unpack(' sample-tagged + symlink_support.symlink(self.tagged, self.package_name) + + assert os.path.isdir(self.package_name) + assert os.path.isfile(os.path.join(self.package_name, '__init__.py')) + + @property + def tagged(self): + return self.package_name + '-tagged' + + # regression test for issue6727 + @unittest.skipUnless( + not hasattr(sys, 'getwindowsversion') + or sys.getwindowsversion() >= (6, 0), + "Windows Vista or later required") + @symlink_support.skip_unless_symlink + def test_symlinked_dir_importable(self): + # make sure sample can only be imported from the current directory. + sys.path[:] = ['.'] + + # and try to import the package + __import__(self.package_name) + + def tearDown(self): + # now cleanup + if os.path.exists(self.package_name): + symlink_support.remove_symlink(self.package_name) + if os.path.exists(self.tagged): + shutil.rmtree(self.tagged) + sys.path[:] = self.orig_sys_path + +def test_main(verbose=None): + run_unittest(ImportTests, PycRewritingTests, PathsTests, + RelativeImportTests, TestSymbolicallyLinkedPackage) + +if __name__ == '__main__': + # Test needs to be a package, so we can do relative imports. 
+ from test.test_import import test_main + test_main() diff --git a/src/org/python/core/PySystemState.java b/src/org/python/core/PySystemState.java --- a/src/org/python/core/PySystemState.java +++ b/src/org/python/core/PySystemState.java @@ -29,6 +29,7 @@ import java.util.StringTokenizer; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.locks.ReentrantLock; import java.util.jar.JarEntry; import java.util.jar.JarFile; @@ -130,6 +131,9 @@ public PyList argv = new PyList(); public PyObject modules; + public PyObject modules_reloading; + private ReentrantLock importLock; + private ClassLoader syspathJavaLoader; public PyList path; public PyList warnoptions = new PyList(); @@ -190,6 +194,9 @@ initialize(); closer = new PySystemStateCloser(this); modules = new PyStringMap(); + modules_reloading = new PyStringMap(); + importLock = new ReentrantLock(); + syspathJavaLoader = new SyspathJavaLoader(imp.getParentClassLoader()); argv = (PyList)defaultArgv.repeat(1); path = (PyList)defaultPath.repeat(1); @@ -337,6 +344,14 @@ return codecState; } + public ReentrantLock getImportLock() { + return importLock; + } + + public ClassLoader getSyspathJavaLoader() { + return syspathJavaLoader; + } + // xxx fix this accessors @Override public PyObject __findattr_ex__(String name) { diff --git a/src/org/python/core/imp.java b/src/org/python/core/imp.java --- a/src/org/python/core/imp.java +++ b/src/org/python/core/imp.java @@ -30,26 +30,15 @@ public static final int NO_MTIME = -1; - // This should change to 0 for Python 2.7 and 3.0 see PEP 328 + // This should change to Python 3.x; note that 2.7 allows relative + // imports unless `from __future__ import absolute_import` public static final int DEFAULT_LEVEL = -1; /** A non-empty fromlist for __import__'ing sub-modules. 
*/ private static final PyObject nonEmptyFromlist = new PyTuple(Py.newString("__doc__")); - /** Synchronizes import operations */ - public static final ReentrantLock importLock = new ReentrantLock(); - - private static Object syspathJavaLoaderLock = new Object(); - - private static ClassLoader syspathJavaLoader = null; - public static ClassLoader getSyspathJavaLoader() { - synchronized (syspathJavaLoaderLock) { - if (syspathJavaLoader == null) { - syspathJavaLoader = new SyspathJavaLoader(getParentClassLoader()); - } - } - return syspathJavaLoader; + return Py.getSystemState().getSyspathJavaLoader(); } /** @@ -392,14 +381,18 @@ Py.writeDebug(IMPORT_LOG, String.format("Warning: %s __file__ is unknown", name)); } + ReentrantLock importLock = Py.getSystemState().getImportLock(); + importLock.lock(); try { PyFrame f = new PyFrame(code, module.__dict__, module.__dict__, null); code.call(Py.getThreadState(), f); + return module; } catch (RuntimeException t) { removeModule(name); throw t; + } finally { + importLock.unlock(); } - return module; } static PyObject createFromClass(String name, Class c) { @@ -521,7 +514,13 @@ static PyObject loadFromLoader(PyObject importer, String name) { PyObject load_module = importer.__getattr__("load_module"); - return load_module.__call__(new PyObject[] { new PyString(name) }); + ReentrantLock importLock = Py.getSystemState().getImportLock(); + importLock.lock(); + try { + return load_module.__call__(new PyObject[]{new PyString(name)}); + } finally { + importLock.unlock(); + } } public static PyObject loadFromCompiled(String name, InputStream stream, String sourceName, @@ -634,7 +633,13 @@ * @return the loaded module */ public static PyObject load(String name) { - return import_first(name, new StringBuilder()); + ReentrantLock importLock = Py.getSystemState().getImportLock(); + importLock.lock(); + try { + return import_first(name, new StringBuilder()); + } finally { + importLock.unlock(); + } } /** @@ -931,7 +936,13 @@ * @return an 
imported module (Java or Python) */ public static PyObject importName(String name, boolean top) { - return import_module_level(name, top, null, null, DEFAULT_LEVEL); + ReentrantLock importLock = Py.getSystemState().getImportLock(); + importLock.lock(); + try { + return import_module_level(name, top, null, null, DEFAULT_LEVEL); + } finally { + importLock.unlock(); + } } /** @@ -945,6 +956,7 @@ */ public static PyObject importName(String name, boolean top, PyObject modDict, PyObject fromlist, int level) { + ReentrantLock importLock = Py.getSystemState().getImportLock(); importLock.lock(); try { return import_module_level(name, top, modDict, fromlist, level); @@ -959,7 +971,7 @@ */ @Deprecated public static PyObject importOne(String mod, PyFrame frame) { - return importOne(mod, frame, imp.DEFAULT_LEVEL); + return importOne(mod, frame, imp.DEFAULT_LEVEL); } /** * Called from jython generated code when a statement like "import spam" is diff --git a/src/org/python/modules/_imp.java b/src/org/python/modules/_imp.java --- a/src/org/python/modules/_imp.java +++ b/src/org/python/modules/_imp.java @@ -299,7 +299,7 @@ * */ public static void acquire_lock() { - org.python.core.imp.importLock.lock(); + Py.getSystemState().getImportLock().lock(); } /** @@ -308,7 +308,7 @@ */ public static void release_lock() { try{ - org.python.core.imp.importLock.unlock(); + Py.getSystemState().getImportLock().unlock(); }catch(IllegalMonitorStateException e){ throw Py.RuntimeError("not holding the import lock"); } @@ -320,6 +320,6 @@ * @return true if the import lock is currently held, else false. 
*/ public static boolean lock_held() { - return org.python.core.imp.importLock.isHeldByCurrentThread(); + return Py.getSystemState().getImportLock().isHeldByCurrentThread(); } } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 17 01:31:34 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 17 Dec 2014 00:31:34 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Change_dict=2Epop_such_that?= =?utf-8?q?_KeyError_uses_toString=28=29?= Message-ID: <20141217003129.28218.23473@psf.io> https://hg.python.org/jython/rev/40daa8bd3613 changeset: 7460:40daa8bd3613 user: Jim Baker date: Tue Dec 16 17:31:07 2014 -0700 summary: Change dict.pop such that KeyError uses toString() Minor fix from https://hg.python.org/jython/rev/f68d65a141f7, which was using key.asString() - this was throwing a TypeError for keys that were not str/unicode files: src/org/python/core/PyDictionary.java | 2 +- 1 files changed, 1 insertions(+), 1 deletions(-) diff --git a/src/org/python/core/PyDictionary.java b/src/org/python/core/PyDictionary.java --- a/src/org/python/core/PyDictionary.java +++ b/src/org/python/core/PyDictionary.java @@ -618,7 +618,7 @@ final PyObject dict_pop(PyObject key, PyObject defaultValue) { if (!getMap().containsKey(key)) { if (defaultValue == null) { - throw Py.KeyError(key.asString()); + throw Py.KeyError(key.toString()); } return defaultValue; } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 17 01:35:13 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 17 Dec 2014 00:35:13 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_No_longer_causes_a_stack_ov?= =?utf-8?q?erflow_on_=22infinite_reloads=22?= Message-ID: <20141217003457.28226.55480@psf.io> https://hg.python.org/jython/rev/8708596b329f changeset: 7461:8708596b329f user: Jim Baker date: Tue Dec 16 17:34:50 2014 -0700 summary: No longer causes a stack overflow on "infinite reloads" Previously infinite_reload.py 
in test_import would cause a RuntimeError due to a stack overflow. Now handles the same as CPython by tracking which module is currently being reloaded, and returns immediately with that module. files: src/org/python/core/PySystemState.java | 5 +- src/org/python/core/imp.java | 26 ++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/org/python/core/PySystemState.java b/src/org/python/core/PySystemState.java --- a/src/org/python/core/PySystemState.java +++ b/src/org/python/core/PySystemState.java @@ -19,6 +19,7 @@ import java.nio.charset.Charset; import java.nio.charset.UnsupportedCharsetException; import java.security.AccessControlException; +import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashSet; import java.util.LinkedList; @@ -131,7 +132,7 @@ public PyList argv = new PyList(); public PyObject modules; - public PyObject modules_reloading; + public Map modules_reloading; private ReentrantLock importLock; private ClassLoader syspathJavaLoader; public PyList path; @@ -194,7 +195,7 @@ initialize(); closer = new PySystemStateCloser(this); modules = new PyStringMap(); - modules_reloading = new PyStringMap(); + modules_reloading = new HashMap(); importLock = new ReentrantLock(); syspathJavaLoader = new SyspathJavaLoader(imp.getParentClassLoader()); diff --git a/src/org/python/core/imp.java b/src/org/python/core/imp.java --- a/src/org/python/core/imp.java +++ b/src/org/python/core/imp.java @@ -7,6 +7,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; +import java.util.Map; import java.util.concurrent.locks.ReentrantLock; import org.python.compiler.Module; @@ -1146,15 +1147,34 @@ } static PyObject reload(PyModule m) { + PySystemState sys = Py.getSystemState(); + PyObject modules = sys.modules; + Map modules_reloading = sys.modules_reloading; + ReentrantLock importLock = Py.getSystemState().getImportLock(); + importLock.lock(); + try { + return _reload(m, modules, 
modules_reloading); + } finally { + modules_reloading.clear(); + importLock.unlock(); + } + } + + private static PyObject _reload(PyModule m, PyObject modules, Map modules_reloading) { String name = m.__getattr__("__name__").toString().intern(); - - PyObject modules = Py.getSystemState().modules; PyModule nm = (PyModule) modules.__finditem__(name); - if (nm == null || !nm.__getattr__("__name__").toString().equals(name)) { throw Py.ImportError("reload(): module " + name + " not in sys.modules"); } + PyModule existing_module = modules_reloading.get(name); + if (existing_module != null) { + // Due to a recursive reload, this module is already being reloaded. + return existing_module; + } + // Since we are already in a re-entrant lock, + // this test & set is guaranteed to be atomic + modules_reloading.put(name, nm); PyList path = Py.getSystemState().path; String modName = name; -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 20 09:10:47 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 20 Dec 2014 08:10:47 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fixes_keyword_inspection_an?= =?utf-8?q?d_call_validation=2E?= Message-ID: <20141220081044.28216.47422@psf.io> https://hg.python.org/jython/rev/1b83af91c320 changeset: 7464:1b83af91c320 user: Jim Baker date: Sat Dec 20 01:10:39 2014 -0700 summary: Fixes keyword inspection and call validation. 
Among other things, Jython now throws a TypeError when invoking no-arg functions with keywords args, such as the example below: def f(): return 42 f(x=47) # now throws TypeError: f() takes no arguments (1 given) Still needs to finish support for more of the tuple args tests in test_inspect files: src/org/python/core/PyBaseCode.java | 49 +++++++++------- 1 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/org/python/core/PyBaseCode.java b/src/org/python/core/PyBaseCode.java --- a/src/org/python/core/PyBaseCode.java +++ b/src/org/python/core/PyBaseCode.java @@ -5,6 +5,7 @@ package org.python.core; import org.python.modules._systemrestart; +import com.google.common.base.CharMatcher; public abstract class PyBaseCode extends PyCode { @@ -199,14 +200,14 @@ public PyObject call(ThreadState state, PyObject args[], String kws[], PyObject globals, PyObject[] defs, PyObject closure) { - PyFrame frame = new PyFrame(this, globals); - int argcount = args.length - kws.length; + final PyFrame frame = new PyFrame(this, globals); + final int argcount = args.length - kws.length; - if (co_argcount > 0 || (varargs || varkwargs)) { + if ((co_argcount > 0) || varargs || varkwargs) { int i; int n = argcount; PyObject kwdict = null; - PyObject[] fastlocals = frame.f_fastlocals; + final PyObject[] fastlocals = frame.f_fastlocals; if (varkwargs) { kwdict = new PyDictionary(); i = co_argcount; @@ -222,9 +223,9 @@ co_name, defcount > 0 ? "at most" : "exactly", co_argcount, - kws.length > 0 ? "non-keyword " : "", + kws.length > 0 ? "" : "", co_argcount == 1 ? 
"" : "s", - argcount); + args.length); throw Py.TypeError(msg); } n = co_argcount; @@ -242,11 +243,6 @@ String keyword = kws[i]; PyObject value = args[i + argcount]; int j; - // XXX: keywords aren't PyObjects, can't ensure strings - //if (keyword == null || keyword.getClass() != PyString.class) { - // throw Py.TypeError(String.format("%.200s() keywords must be strings", - // co_name)); - //} for (j = 0; j < co_argcount; j++) { if (co_varnames[j].equals(keyword)) { break; @@ -254,11 +250,16 @@ } if (j >= co_argcount) { if (kwdict == null) { - throw Py.TypeError(String.format("%.200s() got an unexpected keyword " - + "argument '%.400s'", - co_name, keyword)); + throw Py.TypeError(String.format( + "%.200s() got an unexpected keyword argument '%.400s'", + co_name, + Py.newUnicode(keyword).encode("ascii", "replace"))); } - kwdict.__setitem__(keyword, value); + if (CharMatcher.ASCII.matchesAllOf(keyword)) { + kwdict.__setitem__(keyword, value); + } else { + kwdict.__setitem__(Py.newUnicode(keyword), value); + } } else { if (fastlocals[j] != null) { throw Py.TypeError(String.format("%.200s() got multiple values for " @@ -269,16 +270,18 @@ } } if (argcount < co_argcount) { - int defcount = defs != null ? defs.length : 0; - int m = co_argcount - defcount; + final int defcount = defs != null ? defs.length : 0; + final int m = co_argcount - defcount; for (i = argcount; i < m; i++) { if (fastlocals[i] == null) { String msg = String.format("%.200s() takes %s %d %sargument%s (%d given)", - co_name, (varargs || defcount > 0) ? - "at least" : "exactly", - m, kws.length > 0 ? "non-keyword " : "", - m == 1 ? "" : "s", i); + co_name, + (varargs || defcount > 0) ? "at least" : "exactly", + m, + kws.length > 0 ? "" : "", + m == 1 ? 
"" : "s", + args.length); throw Py.TypeError(msg); } } @@ -293,9 +296,9 @@ } } } - } else if (argcount > 0) { + } else if ((argcount > 0) || (args.length > 0 && (co_argcount == 0 && !varargs && !varkwargs))) { throw Py.TypeError(String.format("%.200s() takes no arguments (%d given)", - co_name, argcount)); + co_name, args.length)); } if (co_flags.isFlagSet(CodeFlag.CO_GENERATOR)) { -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 21 07:10:36 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 21 Dec 2014 06:10:36 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Minor_fix_to_test=5Finspect?= =?utf-8?q?_so_that_all_tests_in_it_pass?= Message-ID: <20141221061035.28218.15309@psf.io> https://hg.python.org/jython/rev/66357235ee2e changeset: 7465:66357235ee2e user: Jim Baker date: Sat Dec 20 23:10:32 2014 -0700 summary: Minor fix to test_inspect so that all tests in it pass It's possible to get both the tuple parameters AND the unpacked parameters by using locals(). However, they are named differently in CPython and Jython: * For CPython, such tuple parameters are named '.1', '.2', etc. * For Jython, they are actually the formal parameter, eg '(d, (e, f))' test_inspect actually wants to ignore such tuple parameters when verifying working with function signatures, given they are already unpacked. Updated is_tuplename regex accordingly. files: Lib/test/test_inspect.py | 11 +++++++++-- 1 files changed, 9 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -633,8 +633,15 @@ class TestGetcallargsFunctions(unittest.TestCase): - # tuple parameters are named '.1', '.2', etc. - is_tuplename = re.compile(r'^\.\d+$').match + # It's possible to get both the tuple parameters AND the unpacked + # parameters by using locals(). 
However, they are named + # differently in CPython and Jython: + # + # * For CPython, such tuple parameters are named '.1', '.2', etc. + # * For Jython, they are actually the formal parameter, eg '(d, (e, f))' + # + # In both cases, we ignore in testing - they are in fact unpacked + is_tuplename = re.compile(r'(?:^\.\d+$)|(?:^\()').match def assertEqualCallArgs(self, func, call_params_string, locs=None): locs = dict(locs or {}, func=func) -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 23 17:41:49 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 23 Dec 2014 16:41:49 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Adds_comparison_ops_to_prox?= =?utf-8?q?ied_java=2Eutil=2EMap_objects?= Message-ID: <20141223164129.120061.35633@psf.io> https://hg.python.org/jython/rev/fc24f60fb32c changeset: 7466:fc24f60fb32c user: Jim Baker date: Tue Dec 23 09:41:23 2014 -0700 summary: Adds comparison ops to proxied java.util.Map objects Completes fix for http://bugs.jython.org/issue1631 by adding __le__, __lt__, __gt__, __ge__ comparison operators. Now any proxied Map objects are compared by their key sets for ordering comparisons. Previously the default comparison was done, which is by their class name (eg "HashMap") and then by System.identityHashCode. (This default comparison is removed in Python 3.) Note that when comparing java.util.Map objects with each other, Java semantics apply. This means that keys are not converted to Python, then compared. Likely the only scenario this comes up is creating a Java Map from Python objects, specifically of int and long objects, then doing a comparison. Example: HashMap({1L:2L}) != HashMap({1:2}) but HashMap({1L:2L}) == {1:2} Supporting Python 3 will fix this problem by removing the distinction between say 1 and 1L. 
files: Lib/test/test_dict_jy.py | 69 ++++++- src/org/python/core/PyJavaType.java | 150 +++++++++++---- 2 files changed, 164 insertions(+), 55 deletions(-) diff --git a/Lib/test/test_dict_jy.py b/Lib/test/test_dict_jy.py --- a/Lib/test/test_dict_jy.py +++ b/Lib/test/test_dict_jy.py @@ -1,5 +1,5 @@ from test import test_support -import java +from java.util import HashMap, Hashtable import unittest from collections import defaultdict import test_dict @@ -114,7 +114,7 @@ class JavaIntegrationTest(unittest.TestCase): "Tests for instantiating dicts from Java maps and hashtables" def test_hashmap(self): - x = java.util.HashMap() + x = HashMap() x.put('a', 1) x.put('b', 2) x.put('c', 3) @@ -123,7 +123,7 @@ self.assertEqual(set(y.items()), set([('a', 1), ('b', 2), ('c', 3), ((1,2), "xyz")])) def test_hashmap_builtin_pymethods(self): - x = java.util.HashMap() + x = HashMap() x['a'] = 1 x[(1, 2)] = 'xyz' self.assertEqual({tup for tup in x.iteritems()}, {('a', 1), ((1, 2), 'xyz')}) @@ -132,18 +132,18 @@ def test_hashtable_equal(self): for d in ({}, {1:2}): - x = java.util.Hashtable(d) + x = Hashtable(d) self.assertEqual(x, d) self.assertEqual(d, x) - self.assertEqual(x, java.util.HashMap(d)) + self.assertEqual(x, HashMap(d)) def test_hashtable_remove(self): - x = java.util.Hashtable({}) + x = Hashtable({}) with self.assertRaises(KeyError): del x[0] def test_hashtable(self): - x = java.util.Hashtable() + x = Hashtable() x.put('a', 1) x.put('b', 2) x.put('c', 3) @@ -154,10 +154,10 @@ class JavaDictTest(test_dict.DictTest): - _class = java.util.HashMap + _class = HashMap def test_copy_java_hashtable(self): - x = java.util.Hashtable() + x = Hashtable() xc = x.copy() self.assertEqual(type(x), type(xc)) @@ -179,6 +179,57 @@ self.assertEqual(x.__delitem__(1), None) self.assertEqual(len(x), 0) + def assert_property(self, prop, a, b): + prop(self._make_dict(a), self._make_dict(b)) + prop(a, self._make_dict(b)) + prop(self._make_dict(a), b) + + def assert_not_property(self, prop, a, b): 
+ with self.assertRaises(AssertionError): + prop(self._make_dict(a), self._make_dict(b)) + with self.assertRaises(AssertionError): + prop(a, self._make_dict(b)) + with self.assertRaises(AssertionError): + prop(self._make_dict(a), b) + + # NOTE: when comparing dictionaries below exclusively in Java + # space, keys like 1 and 1L are different objects. Only when they + # are brought into Python space by Py.java2py, as is needed when + # comparing a Python dict with a Java Map, do we see them become + # equal. + + def test_le(self): + self.assert_property(self.assertLessEqual, {}, {}) + self.assert_property(self.assertLessEqual, {1: 2}, {1: 2}) + self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2}) + self.assert_property(self.assertLessEqual, {}, {1: 2}) + self.assertLessEqual(self._make_dict({1: 2}), {1L: 2L, 3L: 4L}) + self.assertLessEqual({1L: 2L}, self._make_dict({1: 2, 3L: 4L})) + + def test_lt(self): + self.assert_not_property(self.assertLess, {}, {}) + self.assert_not_property(self.assertLess, {1: 2}, {1: 2}) + self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2}) + self.assert_property(self.assertLessEqual, {}, {1: 2}) + self.assertLess(self._make_dict({1: 2}), {1L: 2L, 3L: 4L}) + self.assertLess({1L: 2L}, self._make_dict({1: 2, 3L: 4L})) + + def test_ge(self): + self.assert_property(self.assertGreaterEqual, {}, {}) + self.assert_property(self.assertGreaterEqual, {1: 2}, {1: 2}) + self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2}) + self.assert_property(self.assertLessEqual, {}, {1: 2}) + self.assertGreaterEqual(self._make_dict({1: 2, 3: 4}), {1L: 2L}) + self.assertGreaterEqual({1L: 2L, 3L: 4L}, self._make_dict({1: 2})) + + def test_gt(self): + self.assert_not_property(self.assertGreater, {}, {}) + self.assert_not_property(self.assertGreater, {1: 2}, {1: 2}) + self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2}) + self.assert_property(self.assertLessEqual, {}, {1: 2}) + 
self.assertGreater(self._make_dict({1: 2, 3: 4}), {1L: 2L}) + self.assertGreater({1L: 2L, 3L: 4L}, self._make_dict({1: 2})) + def test_main(): test_support.run_unittest(DictInitTest, DictCmpTest, DerivedDictTest, JavaIntegrationTest, JavaDictTest) diff --git a/src/org/python/core/PyJavaType.java b/src/org/python/core/PyJavaType.java --- a/src/org/python/core/PyJavaType.java +++ b/src/org/python/core/PyJavaType.java @@ -946,6 +946,58 @@ protected abstract boolean getResult(int comparison); } + private static PyObject mapEq(PyObject self, PyObject other) { + Map selfMap = ((Map) self.getJavaProxy()); + if (other.getType().isSubType(PyDictionary.TYPE)) { + PyDictionary oDict = (PyDictionary) other; + if (selfMap.size() != oDict.size()) { + return Py.False; + } + for (Object jkey : selfMap.keySet()) { + Object jval = selfMap.get(jkey); + PyObject oVal = oDict.__finditem__(Py.java2py(jkey)); + if (oVal == null) { + return Py.False; + } + if (!Py.java2py(jval)._eq(oVal).__nonzero__()) { + return Py.False; + } + } + return Py.True; + } else { + Object oj = other.getJavaProxy(); + if (oj instanceof Map) { + Map oMap = (Map) oj; + return Py.newBoolean(selfMap.equals(oMap)); + } else { + return null; + } + } + } + + // Map ordering comparisons (lt, le, gt, ge) are based on the key sets; + // we just define mapLe + mapEq for total ordering of such key sets + private static PyObject mapLe(PyObject self, PyObject other) { + Set selfKeys = ((Map) self.getJavaProxy()).keySet(); + if (other.getType().isSubType(PyDictionary.TYPE)) { + PyDictionary oDict = (PyDictionary) other; + for (Object jkey : selfKeys) { + if (!oDict.__contains__(Py.java2py(jkey))) { + return Py.False; + } + } + return Py.True; + } else { + Object oj = other.getJavaProxy(); + if (oj instanceof Map) { + Map oMap = (Map) oj; + return Py.newBoolean(oMap.keySet().containsAll(selfKeys)); + } else { + return null; + } + } + } + /** * Build a map of common Java collection base types (Map, Iterable, etc) that need to 
be * injected with Python's equivalent types' builtin methods (__len__, __iter__, iteritems, etc). @@ -1030,31 +1082,31 @@ PyBuiltinMethodNarrow mapEqProxy = new MapMethod("__eq__", 1) { @Override public PyObject __call__(PyObject other) { - if (other.getType().isSubType(PyDictionary.TYPE)) { - PyDictionary oDict = (PyDictionary) other; - if (asMap().size() != oDict.size()) { - return Py.False; - } - for (Object jkey : asMap().keySet()) { - Object jval = asMap().get(jkey); - PyObject oVal = oDict.__finditem__(Py.java2py(jkey)); - if (oVal == null) { - return Py.False; - } - if (!Py.java2py(jval)._eq(oVal).__nonzero__()) { - return Py.False; - } - } - return Py.True; - } else { - Object oj = other.getJavaProxy(); - if (oj instanceof Map) { - Map oMap = (Map) oj; - return asMap().equals(oMap) ? Py.True : Py.False; - } else { - return null; - } - } + return mapEq(self, other); + } + }; + PyBuiltinMethodNarrow mapLeProxy = new MapMethod("__le__", 1) { + @Override + public PyObject __call__(PyObject other) { + return mapLe(self, other); + } + }; + PyBuiltinMethodNarrow mapGeProxy = new MapMethod("__ge__", 1) { + @Override + public PyObject __call__(PyObject other) { + return (mapLe(self, other).__not__()).__or__(mapEq(self, other)); + } + }; + PyBuiltinMethodNarrow mapLtProxy = new MapMethod("__lt__", 1) { + @Override + public PyObject __call__(PyObject other) { + return mapLe(self, other).__and__(mapEq(self, other).__not__()); + } + }; + PyBuiltinMethodNarrow mapGtProxy = new MapMethod("__gt__", 1) { + @Override + public PyObject __call__(PyObject other) { + return mapLe(self, other).__not__(); } }; PyBuiltinMethodNarrow mapIterProxy = new MapMethod("__iter__", 0) { @@ -1340,27 +1392,33 @@ } } }; - collectionProxies.put(Map.class, new PyBuiltinMethod[] {mapLenProxy, - // map IterProxy can conflict with Iterable.class; fix after the fact in handleMroError - mapIterProxy, - mapReprProxy, - mapEqProxy, - mapContainsProxy, - mapGetItemProxy, - //mapGetProxy, - mapPutProxy, 
- mapRemoveProxy, - mapIterItemsProxy, - mapHasKeyProxy, - mapKeysProxy, - //mapValuesProxy, - mapSetDefaultProxy, - mapPopProxy, - mapPopItemProxy, - mapItemsProxy, - mapCopyProxy, - mapUpdateProxy, - mapFromKeysProxy}); // class method + collectionProxies.put(Map.class, new PyBuiltinMethod[] { + mapLenProxy, + // map IterProxy can conflict with Iterable.class; + // fix after the fact in handleMroError + mapIterProxy, + mapReprProxy, + mapEqProxy, + mapLeProxy, + mapLtProxy, + mapGeProxy, + mapGtProxy, + mapContainsProxy, + mapGetItemProxy, + //mapGetProxy, + mapPutProxy, + mapRemoveProxy, + mapIterItemsProxy, + mapHasKeyProxy, + mapKeysProxy, + //mapValuesProxy, + mapSetDefaultProxy, + mapPopProxy, + mapPopItemProxy, + mapItemsProxy, + mapCopyProxy, + mapUpdateProxy, + mapFromKeysProxy}); // class method postCollectionProxies.put(Map.class, new PyBuiltinMethod[] {mapGetProxy, mapValuesProxy}); -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 24 01:20:54 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 24 Dec 2014 00:20:54 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Triage_and_fix_remaining_te?= =?utf-8?q?st=5Fjson_failures?= Message-ID: <20141224002049.102363.38443@psf.io> https://hg.python.org/jython/rev/5b7c389f667b changeset: 7467:5b7c389f667b user: Jim Baker date: Tue Dec 23 17:20:13 2014 -0700 summary: Triage and fix remaining test_json failures Fixed bug in circular reference checking, while simplifying. Skip test_tool failing tests until http://bugs.jython.org/issue695383 is fixed. Check encoding name is itself ascii in creating a JSON encoder as well as constructing a unicode object, raising UnicodeEncodeError otherwise. 
files: Lib/json/tests/test_tool.py | 75 ++++++++++ src/org/python/core/PyUnicode.java | 10 +- src/org/python/modules/_json/Encoder.java | 80 +++------- src/org/python/modules/_json/Scanner.java | 19 ++- 4 files changed, 125 insertions(+), 59 deletions(-) diff --git a/Lib/json/tests/test_tool.py b/Lib/json/tests/test_tool.py new file mode 100644 --- /dev/null +++ b/Lib/json/tests/test_tool.py @@ -0,0 +1,75 @@ +import os +import sys +import textwrap +import unittest +import subprocess +from test import test_support +from test.script_helper import assert_python_ok + +class TestTool(unittest.TestCase): + data = """ + + [["blorpie"],[ "whoops" ] , [ + ],\t"d-shtaeou",\r"d-nthiouh", + "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field" + :"yes"} ] + """ + + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ], + [ + "whoops" + ], + [], + "d-shtaeou", + "d-nthiouh", + "i-vhbjkhnth", + { + "nifty": 87 + }, + { + "field": "yes", + "morefield": false + } + ] + """) + + @unittest.skipIf(test_support.is_jython, "Revisit when http://bugs.jython.org/issue695383 is fixed") + def test_stdin_stdout(self): + proc = subprocess.Popen( + (sys.executable, '-m', 'json.tool'), + stdin=subprocess.PIPE, stdout=subprocess.PIPE) + out, err = proc.communicate(self.data.encode()) + self.assertEqual(out.splitlines(), self.expect.encode().splitlines()) + self.assertEqual(err, None) + + def _create_infile(self): + infile = test_support.TESTFN + with open(infile, "w") as fp: + self.addCleanup(os.remove, infile) + fp.write(self.data) + return infile + + # This is a problem orthogonal to json support, even for usage of + # this tool. Instead it seems to be a problem in simply testing + # it. TODO fix this underlying issue that's been outstanding for a + # while in Jython. 
+ @unittest.skipIf(test_support.is_jython, "Revisit when http://bugs.jython.org/issue695383 is fixed") + def test_infile_stdout(self): + infile = self._create_infile() + rc, out, err = assert_python_ok('-m', 'json.tool', infile) + self.assertEqual(out.splitlines(), self.expect.encode().splitlines()) + self.assertEqual(err, b'') + + def test_infile_outfile(self): + infile = self._create_infile() + outfile = test_support.TESTFN + '.out' + rc, out, err = assert_python_ok('-m', 'json.tool', infile, outfile) + self.addCleanup(os.remove, outfile) + with open(outfile, "r") as fp: + self.assertEqual(fp.read(), self.expect) + self.assertEqual(out, b'') + self.assertEqual(err, b'') diff --git a/src/org/python/core/PyUnicode.java b/src/org/python/core/PyUnicode.java --- a/src/org/python/core/PyUnicode.java +++ b/src/org/python/core/PyUnicode.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Set; +import com.google.common.base.CharMatcher; import org.python.core.stringlib.FieldNameIterator; import org.python.core.stringlib.MarkupIterator; import org.python.expose.ExposedMethod; @@ -578,6 +579,11 @@ return string.length() - translator.suppCount(); } + private static String checkEncoding(String s) { + if (s == null || CharMatcher.ASCII.matchesAllOf(s)) { return s; } + return codecs.PyUnicode_EncodeASCII(s, s.length(), null); + } + @ExposedNew final static PyObject unicode_new(PyNewWrapper new_, boolean init, PyType subtype, PyObject[] args, String[] keywords) { @@ -585,8 +591,8 @@ new ArgParser("unicode", args, keywords, new String[] {"string", "encoding", "errors"}, 0); PyObject S = ap.getPyObject(0, null); - String encoding = ap.getString(1, null); - String errors = ap.getString(2, null); + String encoding = checkEncoding(ap.getString(1, null)); + String errors = checkEncoding(ap.getString(2, null)); if (new_.for_type == subtype) { if (S == null) { return new PyUnicode(""); diff --git a/src/org/python/modules/_json/Encoder.java 
b/src/org/python/modules/_json/Encoder.java --- a/src/org/python/modules/_json/Encoder.java +++ b/src/org/python/modules/_json/Encoder.java @@ -58,7 +58,7 @@ public PyObject __call__(PyObject obj, PyObject indent_level) { PyList rval = new PyList(); - listencode_obj(rval, obj, 0); + encode_obj(rval, obj, 0); return rval; } @@ -86,7 +86,19 @@ return (PyString) encoder.__call__(obj); } - private void listencode_obj(PyList rval, PyObject obj, int indent_level) { + private PyObject checkCircularReference(PyObject obj) { + PyObject ident = null; + if (markers != null) { + ident = Py.newInteger(Py.id(obj)); + if (markers.__contains__(ident)) { + throw Py.ValueError("Circular reference detected"); + } + markers.__setitem__(ident, obj); + } + return ident; + } + + private void encode_obj(PyList rval, PyObject obj, int indent_level) { /* Encode Python object obj to a JSON term, rval is a PyList */ if (obj == Py.None) { rval.append(new PyString("null")); @@ -101,64 +113,31 @@ } else if (obj instanceof PyFloat) { rval.append(encode_float(obj)); } else if (obj instanceof PyList || obj instanceof PyTuple) { - listencode_list(rval, obj, indent_level); + encode_list(rval, obj, indent_level); } else if (obj instanceof PyDictionary) { - listencode_dict(rval, (PyDictionary) obj, indent_level); + encode_dict(rval, (PyDictionary) obj, indent_level); } else { - PyObject ident = null; - if (markers != null) { - boolean contained = false; - try { - contained = markers.__contains__(obj); - } catch (PyException pye) { - // ignore objects that are not hashable, so they can be - // potentially serialized with defaultfn - if (!pye.match(Py.TypeError)) throw pye; - } - if (contained) { - throw Py.ValueError("Circular reference detected"); - } - ident = Py.newInteger(Py.id(obj)); - markers.__setitem__(ident, obj); - } + PyObject ident = checkCircularReference(obj); if (defaultfn == Py.None) { throw Py.TypeError(String.format(".80s is not JSON serializable", obj.__repr__())); } - PyObject 
newobj; - try { - newobj = defaultfn.__call__(obj); - } catch (StackOverflowError e) { - if (markers == Py.None) { - throw e; - } else { - throw Py.ValueError("Stack overflow in JSON serialization"); - } - } - listencode_obj(rval, newobj, indent_level); + PyObject newobj = defaultfn.__call__(obj); + encode_obj(rval, newobj, indent_level); if (ident != null) { markers.__delitem__(ident); } } } - private void listencode_dict(PyList rval, PyDictionary dct, int indent_level) { + private void encode_dict(PyList rval, PyDictionary dct, int indent_level) { /* Encode Python dict dct a JSON term */ - - PyObject ident = null; - if (dct.__len__() == 0) { rval.append(new PyString("{}")); return; } - if (markers != null) { - ident = Py.newInteger(Py.id(dct)); - if (markers.__contains__(ident)) { - throw Py.ValueError("Circular reference detected"); - } - markers.__setitem__(ident, dct); - } + PyObject ident = checkCircularReference(dct); rval.append(new PyString("{")); /* TODO: C speedup not implemented for sort_keys */ @@ -193,7 +172,7 @@ PyString encoded = encode_string(kstr); rval.append(encoded); rval.append(key_separator); - listencode_obj(rval, value, indent_level); + encode_obj(rval, value, indent_level); idx += 1; } @@ -204,17 +183,8 @@ } - private void listencode_list(PyList rval, PyObject seq, int indent_level) { - PyObject ident = null; - - if (markers != null) { - ident = Py.newInteger(Py.id(seq)); - if (markers.__contains__(ident)) { - throw Py.ValueError("Circular reference detected"); - } - markers.__setitem__(ident, seq); - } - + private void encode_list(PyList rval, PyObject seq, int indent_level) { + PyObject ident = checkCircularReference(seq); rval.append(new PyString("[")); int i = 0; @@ -222,7 +192,7 @@ if (i > 0) { rval.append(item_separator); } - listencode_obj(rval, obj, indent_level); + encode_obj(rval, obj, indent_level); i++; } diff --git a/src/org/python/modules/_json/Scanner.java b/src/org/python/modules/_json/Scanner.java --- 
a/src/org/python/modules/_json/Scanner.java +++ b/src/org/python/modules/_json/Scanner.java @@ -7,6 +7,8 @@ import org.python.core.PyString; import org.python.core.PyTuple; import org.python.core.PyType; +import org.python.core.PyUnicode; +import org.python.core.codecs; import org.python.expose.ExposedGet; import org.python.expose.ExposedType; @@ -29,8 +31,7 @@ public Scanner(PyObject context) { super(); - PyObject encoding_obj = context.__getattr__("encoding"); - encoding = encoding_obj == Py.None ? "utf-8" : context.__getattr__("encoding").asString(); + encoding = _castString(context.__getattr__("encoding"), "utf-8"); strict = context.__getattr__("strict").__nonzero__(); object_hook = context.__getattr__("object_hook"); pairs_hook = context.__getattr__("object_pairs_hook"); @@ -47,6 +48,20 @@ return (c == ' ') || (c == '\t') || (c == '\n') || (c == '\r'); } + private static String _castString(PyObject pystr, String defaultValue) { + // Jython used to treat String as equivalent to PyString, or maybe PyUnicode, as + // it made sense. We need to be more careful now! Insert this cast check as necessary + // to ensure the appropriate compliance. 
+ if (pystr == Py.None) { + return defaultValue; + } + if (!(pystr instanceof PyString)) { + throw Py.TypeError("encoding is not a string"); + } + String s = pystr.toString(); + return codecs.PyUnicode_EncodeASCII(s, s.length(), null); + } + static PyTuple valIndex(PyObject obj, int i) { return new PyTuple(obj, Py.newInteger(i)); } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 24 03:28:06 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 24 Dec 2014 02:28:06 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Support_systems_with_unicod?= =?utf-8?q?e_paths_in_test=5Frunpy?= Message-ID: <20141224022800.102379.46732@psf.io> https://hg.python.org/jython/rev/96523e6f90fd changeset: 7468:96523e6f90fd user: Jim Baker date: Tue Dec 23 18:23:42 2014 -0700 summary: Support systems with unicode paths in test_runpy test_zipfile_error in test_runpy was not supporting the exact error text in the ImportError, due to Unicode paths being returned by os.path.realpath on OSX and Windows.
files: Lib/test/test_runpy.py | 402 +++++++++++++++++++++++++++++ 1 files changed, 402 insertions(+), 0 deletions(-) diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py new file mode 100644 --- /dev/null +++ b/Lib/test/test_runpy.py @@ -0,0 +1,402 @@ +# Test the runpy module +import unittest +import os +import os.path +import sys +import re +import tempfile +from test.test_support import verbose, run_unittest, forget +from test.script_helper import (temp_dir, make_script, compile_script, + make_pkg, make_zip_script, make_zip_pkg) + + +from runpy import _run_code, _run_module_code, run_module, run_path +# Note: This module can't safely test _run_module_as_main as it +# runs its tests in the current process, which would mess with the +# real __main__ module (usually test.regrtest) +# See test_cmd_line_script for a test that executes that code path + +# Set up the test code and expected results + +class RunModuleCodeTest(unittest.TestCase): + """Unit tests for runpy._run_code and runpy._run_module_code""" + + expected_result = ["Top level assignment", "Lower level reference"] + test_source = ( + "# Check basic code execution\n" + "result = ['Top level assignment']\n" + "def f():\n" + " result.append('Lower level reference')\n" + "f()\n" + "# Check the sys module\n" + "import sys\n" + "run_argv0 = sys.argv[0]\n" + "run_name_in_sys_modules = __name__ in sys.modules\n" + "if run_name_in_sys_modules:\n" + " module_in_sys_modules = globals() is sys.modules[__name__].__dict__\n" + "# Check nested operation\n" + "import runpy\n" + "nested = runpy._run_module_code('x=1\\n', mod_name='')\n" + ) + + def test_run_code(self): + saved_argv0 = sys.argv[0] + d = _run_code(self.test_source, {}) + self.assertEqual(d["result"], self.expected_result) + self.assertIs(d["__name__"], None) + self.assertIs(d["__file__"], None) + self.assertIs(d["__loader__"], None) + self.assertIs(d["__package__"], None) + self.assertIs(d["run_argv0"], saved_argv0) + self.assertNotIn("run_name", d) 
+ self.assertIs(sys.argv[0], saved_argv0) + + def test_run_module_code(self): + initial = object() + name = "" + file = "Some other nonsense" + loader = "Now you're just being silly" + package = '' # Treat as a top level module + d1 = dict(initial=initial) + saved_argv0 = sys.argv[0] + d2 = _run_module_code(self.test_source, + d1, + name, + file, + loader, + package) + self.assertNotIn("result", d1) + self.assertIs(d2["initial"], initial) + self.assertEqual(d2["result"], self.expected_result) + self.assertEqual(d2["nested"]["x"], 1) + self.assertIs(d2["__name__"], name) + self.assertTrue(d2["run_name_in_sys_modules"]) + self.assertTrue(d2["module_in_sys_modules"]) + self.assertIs(d2["__file__"], file) + self.assertIs(d2["run_argv0"], file) + self.assertIs(d2["__loader__"], loader) + self.assertIs(d2["__package__"], package) + self.assertIs(sys.argv[0], saved_argv0) + self.assertNotIn(name, sys.modules) + + +class RunModuleTest(unittest.TestCase): + """Unit tests for runpy.run_module""" + + def expect_import_error(self, mod_name): + try: + run_module(mod_name) + except ImportError: + pass + else: + self.fail("Expected import error for " + mod_name) + + def test_invalid_names(self): + # Builtin module + self.expect_import_error("sys") + # Non-existent modules + self.expect_import_error("sys.imp.eric") + self.expect_import_error("os.path.half") + self.expect_import_error("a.bee") + self.expect_import_error(".howard") + self.expect_import_error("..eaten") + # Package without __main__.py + self.expect_import_error("multiprocessing") + + def test_library_module(self): + run_module("runpy") + + def _add_pkg_dir(self, pkg_dir): + os.mkdir(pkg_dir) + pkg_fname = os.path.join(pkg_dir, "__init__"+os.extsep+"py") + pkg_file = open(pkg_fname, "w") + pkg_file.close() + return pkg_fname + + def _make_pkg(self, source, depth, mod_base="runpy_test"): + pkg_name = "__runpy_pkg__" + test_fname = mod_base+os.extsep+"py" + pkg_dir = sub_dir = tempfile.mkdtemp() + if verbose: print " 
Package tree in:", sub_dir + sys.path.insert(0, pkg_dir) + if verbose: print " Updated sys.path:", sys.path[0] + for i in range(depth): + sub_dir = os.path.join(sub_dir, pkg_name) + pkg_fname = self._add_pkg_dir(sub_dir) + if verbose: print " Next level in:", sub_dir + if verbose: print " Created:", pkg_fname + mod_fname = os.path.join(sub_dir, test_fname) + mod_file = open(mod_fname, "w") + mod_file.write(source) + mod_file.close() + if verbose: print " Created:", mod_fname + mod_name = (pkg_name+".")*depth + mod_base + return pkg_dir, mod_fname, mod_name + + def _del_pkg(self, top, depth, mod_name): + for entry in list(sys.modules): + if entry.startswith("__runpy_pkg__"): + del sys.modules[entry] + if verbose: print " Removed sys.modules entries" + del sys.path[0] + if verbose: print " Removed sys.path entry" + for root, dirs, files in os.walk(top, topdown=False): + for name in files: + try: + os.remove(os.path.join(root, name)) + except OSError, ex: + if verbose: print ex # Persist with cleaning up + for name in dirs: + fullname = os.path.join(root, name) + try: + os.rmdir(fullname) + except OSError, ex: + if verbose: print ex # Persist with cleaning up + try: + os.rmdir(top) + if verbose: print " Removed package tree" + except OSError, ex: + if verbose: print ex # Persist with cleaning up + + def _check_module(self, depth): + pkg_dir, mod_fname, mod_name = ( + self._make_pkg("x=1\n", depth)) + forget(mod_name) + try: + if verbose: print "Running from source:", mod_name + d1 = run_module(mod_name) # Read from source + self.assertIn("x", d1) + self.assertTrue(d1["x"] == 1) + del d1 # Ensure __loader__ entry doesn't keep file open + __import__(mod_name) + os.remove(mod_fname) + if verbose: print "Running from compiled:", mod_name + d2 = run_module(mod_name) # Read from bytecode + self.assertIn("x", d2) + self.assertTrue(d2["x"] == 1) + del d2 # Ensure __loader__ entry doesn't keep file open + finally: + self._del_pkg(pkg_dir, depth, mod_name) + if verbose: print 
"Module executed successfully" + + def _check_package(self, depth): + pkg_dir, mod_fname, mod_name = ( + self._make_pkg("x=1\n", depth, "__main__")) + pkg_name, _, _ = mod_name.rpartition(".") + forget(mod_name) + try: + if verbose: print "Running from source:", pkg_name + d1 = run_module(pkg_name) # Read from source + self.assertIn("x", d1) + self.assertTrue(d1["x"] == 1) + del d1 # Ensure __loader__ entry doesn't keep file open + __import__(mod_name) + os.remove(mod_fname) + if verbose: print "Running from compiled:", pkg_name + d2 = run_module(pkg_name) # Read from bytecode + self.assertIn("x", d2) + self.assertTrue(d2["x"] == 1) + del d2 # Ensure __loader__ entry doesn't keep file open + finally: + self._del_pkg(pkg_dir, depth, pkg_name) + if verbose: print "Package executed successfully" + + def _add_relative_modules(self, base_dir, source, depth): + if depth <= 1: + raise ValueError("Relative module test needs depth > 1") + pkg_name = "__runpy_pkg__" + module_dir = base_dir + for i in range(depth): + parent_dir = module_dir + module_dir = os.path.join(module_dir, pkg_name) + # Add sibling module + sibling_fname = os.path.join(module_dir, "sibling"+os.extsep+"py") + sibling_file = open(sibling_fname, "w") + sibling_file.close() + if verbose: print " Added sibling module:", sibling_fname + # Add nephew module + uncle_dir = os.path.join(parent_dir, "uncle") + self._add_pkg_dir(uncle_dir) + if verbose: print " Added uncle package:", uncle_dir + cousin_dir = os.path.join(uncle_dir, "cousin") + self._add_pkg_dir(cousin_dir) + if verbose: print " Added cousin package:", cousin_dir + nephew_fname = os.path.join(cousin_dir, "nephew"+os.extsep+"py") + nephew_file = open(nephew_fname, "w") + nephew_file.close() + if verbose: print " Added nephew module:", nephew_fname + + def _check_relative_imports(self, depth, run_name=None): + contents = r"""\ +from __future__ import absolute_import +from . 
import sibling +from ..uncle.cousin import nephew +""" + pkg_dir, mod_fname, mod_name = ( + self._make_pkg(contents, depth)) + try: + self._add_relative_modules(pkg_dir, contents, depth) + pkg_name = mod_name.rpartition('.')[0] + if verbose: print "Running from source:", mod_name + d1 = run_module(mod_name, run_name=run_name) # Read from source + self.assertIn("__package__", d1) + self.assertTrue(d1["__package__"] == pkg_name) + self.assertIn("sibling", d1) + self.assertIn("nephew", d1) + del d1 # Ensure __loader__ entry doesn't keep file open + __import__(mod_name) + os.remove(mod_fname) + if verbose: print "Running from compiled:", mod_name + d2 = run_module(mod_name, run_name=run_name) # Read from bytecode + self.assertIn("__package__", d2) + self.assertTrue(d2["__package__"] == pkg_name) + self.assertIn("sibling", d2) + self.assertIn("nephew", d2) + del d2 # Ensure __loader__ entry doesn't keep file open + finally: + self._del_pkg(pkg_dir, depth, mod_name) + if verbose: print "Module executed successfully" + + def test_run_module(self): + for depth in range(4): + if verbose: print "Testing package depth:", depth + self._check_module(depth) + + def test_run_package(self): + for depth in range(1, 4): + if verbose: print "Testing package depth:", depth + self._check_package(depth) + + def test_explicit_relative_import(self): + for depth in range(2, 5): + if verbose: print "Testing relative imports at depth:", depth + self._check_relative_imports(depth) + + def test_main_relative_import(self): + for depth in range(2, 5): + if verbose: print "Testing main relative imports at depth:", depth + self._check_relative_imports(depth, "__main__") + + +class RunPathTest(unittest.TestCase): + """Unit tests for runpy.run_path""" + # Based on corresponding tests in test_cmd_line_script + + test_source = """\ +# Script may be run with optimisation enabled, so don't rely on assert +# statements being executed +def assertEqual(lhs, rhs): + if lhs != rhs: + raise AssertionError('%r 
!= %r' % (lhs, rhs)) +def assertIs(lhs, rhs): + if lhs is not rhs: + raise AssertionError('%r is not %r' % (lhs, rhs)) +# Check basic code execution +result = ['Top level assignment'] +def f(): + result.append('Lower level reference') +f() +assertEqual(result, ['Top level assignment', 'Lower level reference']) +# Check the sys module +import sys +assertIs(globals(), sys.modules[__name__].__dict__) +argv0 = sys.argv[0] +""" + + def _make_test_script(self, script_dir, script_basename, source=None): + if source is None: + source = self.test_source + return make_script(script_dir, script_basename, source) + + def _check_script(self, script_name, expected_name, expected_file, + expected_argv0, expected_package): + result = run_path(script_name) + self.assertEqual(result["__name__"], expected_name) + self.assertEqual(result["__file__"], expected_file) + self.assertIn("argv0", result) + self.assertEqual(result["argv0"], expected_argv0) + self.assertEqual(result["__package__"], expected_package) + + def _check_import_error(self, script_name, msg): + msg = re.escape(msg) + self.assertRaisesRegexp(ImportError, msg, run_path, script_name) + + def test_basic_script(self): + with temp_dir() as script_dir: + mod_name = 'script' + script_name = self._make_test_script(script_dir, mod_name) + self._check_script(script_name, "", script_name, + script_name, None) + + def test_script_compiled(self): + with temp_dir() as script_dir: + mod_name = 'script' + script_name = self._make_test_script(script_dir, mod_name) + compiled_name = compile_script(script_name) + os.remove(script_name) + self._check_script(compiled_name, "", compiled_name, + compiled_name, None) + + def test_directory(self): + with temp_dir() as script_dir: + mod_name = '__main__' + script_name = self._make_test_script(script_dir, mod_name) + self._check_script(script_dir, "", script_name, + script_dir, '') + + def test_directory_compiled(self): + with temp_dir() as script_dir: + mod_name = '__main__' + script_name = 
self._make_test_script(script_dir, mod_name) + compiled_name = compile_script(script_name) + os.remove(script_name) + self._check_script(script_dir, "", compiled_name, + script_dir, '') + + def test_directory_error(self): + with temp_dir() as script_dir: + mod_name = 'not_main' + script_name = self._make_test_script(script_dir, mod_name) + msg = "can't find '__main__' module in %r" % script_dir + self._check_import_error(script_dir, msg) + + def test_zipfile(self): + with temp_dir() as script_dir: + mod_name = '__main__' + script_name = self._make_test_script(script_dir, mod_name) + zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name) + self._check_script(zip_name, "", fname, zip_name, '') + + def test_zipfile_compiled(self): + with temp_dir() as script_dir: + mod_name = '__main__' + script_name = self._make_test_script(script_dir, mod_name) + compiled_name = compile_script(script_name) + zip_name, fname = make_zip_script(script_dir, 'test_zip', compiled_name) + self._check_script(zip_name, "", fname, zip_name, '') + + def test_zipfile_error(self): + with temp_dir() as script_dir: + mod_name = 'not_main' + script_name = self._make_test_script(script_dir, mod_name) + zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name) + msg = "can't find '__main__' module in '%s'" % zip_name + self._check_import_error(zip_name, msg) + + def test_main_recursion_error(self): + with temp_dir() as script_dir, temp_dir() as dummy_dir: + mod_name = '__main__' + source = ("import runpy\n" + "runpy.run_path(%r)\n") % dummy_dir + script_name = self._make_test_script(script_dir, mod_name, source) + zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name) + msg = "recursion depth exceeded" + self.assertRaisesRegexp(RuntimeError, msg, run_path, zip_name) + + + +def test_main(): + run_unittest(RunModuleCodeTest, RunModuleTest, RunPathTest) + +if __name__ == "__main__": + test_main() -- Repository URL: https://hg.python.org/jython From 
jython-checkins at python.org Wed Dec 24 18:36:00 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 24 Dec 2014 17:36:00 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Emit_warning_or_raise_excep?= =?utf-8?q?tion_for_nonexistent_parent_module?= Message-ID: <20141224173547.71909.56441@psf.io> https://hg.python.org/jython/rev/8d6bf085fbeb changeset: 7469:8d6bf085fbeb user: Jim Baker date: Wed Dec 24 10:35:40 2014 -0700 summary: Emit warning or raise exception for nonexistent parent module Python expects __package__, when set, to correspond to the parent module; this is then used to support relative imports. Ordinarily this setting is done as part of the import process, however, it is possible for user code to change or set, including to a nonexistent module. If the parent doesn't exist, we follow what was done for http://bugs.python.org/issue3221: * For absolute imports, emit a RuntimeWarning, since it's possible that __package__ is being used for other purposes in code written prior to Python 2.6 * For relative imports, raise a SystemError Ports the corresponding fix from bug 3221 in the get_parent function in import.c files: src/org/python/core/imp.java | 15 +++++++++++++++ 1 files changed, 15 insertions(+), 0 deletions(-) diff --git a/src/org/python/core/imp.java b/src/org/python/core/imp.java --- a/src/org/python/core/imp.java +++ b/src/org/python/core/imp.java @@ -676,6 +676,8 @@ */ private static String get_parent(PyObject dict, int level) { String modname; + int orig_level = level; + if ((dict == null && level == -1) || level == 0) { // try an absolute import return null; @@ -726,6 +728,19 @@ } modname = modname.substring(0, dot); } + + if (Py.getSystemState().modules.__finditem__(modname) == null) { + if (orig_level < 1) { + Py.warning(Py.RuntimeWarning, + String.format( + "Parent module '%.200s' not found " + + "while handling absolute import", modname)); + } else { + throw Py.SystemError(String.format( + "Parent module '%.200s' not
loaded, " + + "cannot perform relative import", modname)); + } + } return modname.intern(); } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 24 22:17:16 2014 From: jython-checkins at python.org (jim.baker) Date: Wed, 24 Dec 2014 21:17:16 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Retain_target_file_name=2C_?= =?utf-8?q?to_be_used_in_code_objects_if_source_not_available?= Message-ID: <20141224211706.22210.30182@psf.io> https://hg.python.org/jython/rev/ae190b66bf4c changeset: 7470:ae190b66bf4c user: Jim Baker date: Wed Dec 24 14:17:00 2014 -0700 summary: Retain target file name, to be used in code objects if source not available If the compiled (*.py$class) file is available without the companion source file, code objects when loaded now set co_filename to the target filename originally associated with that file, regardless of what the name of the file is now. Note as with CPython, modules can be compiled with a different target filename than their source filename. Bumps bytecode magic (org.python.core.imp.APIVersion) to 35. Adds a new retained annotation, org.python.compiler.Filename, to annotate compiled modules, as implemented using PyFunctionTable. This change can be seen in the following decompilation of a compiled Python module: import org.python.compiler.*; import org.python.core.*; @APIVersion(35) @MTime(1419447569000L) @Filename("foo.py") public class foo$py extends PyFunctionTable implements PyRunnable { ... 
} Fixes the last failing test case in test_import, test_module_without_source files: src/org/python/compiler/ClassFile.java | 5 +- src/org/python/compiler/Filename.java | 9 + src/org/python/core/AnnotationReader.java | 14 ++- src/org/python/core/imp.java | 63 +++++++++- 4 files changed, 78 insertions(+), 13 deletions(-) diff --git a/src/org/python/compiler/ClassFile.java b/src/org/python/compiler/ClassFile.java --- a/src/org/python/compiler/ClassFile.java +++ b/src/org/python/compiler/ClassFile.java @@ -189,7 +189,7 @@ av.visitEnd(); } } - + public void write(OutputStream stream) throws IOException { cw.visit(Opcodes.V1_5, Opcodes.ACC_PUBLIC + Opcodes.ACC_SUPER, this.name, null, this.superclass, interfaces); AnnotationVisitor av = cw.visitAnnotation("Lorg/python/compiler/APIVersion;", true); @@ -203,6 +203,9 @@ av.visitEnd(); if (sfilename != null) { + av = cw.visitAnnotation("Lorg/python/compiler/Filename;", true); + av.visit("value", sfilename); + av.visitEnd(); cw.visitSource(sfilename, null); } endClassAnnotations(); diff --git a/src/org/python/compiler/Filename.java b/src/org/python/compiler/Filename.java new file mode 100644 --- /dev/null +++ b/src/org/python/compiler/Filename.java @@ -0,0 +1,9 @@ +package org.python.compiler; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +@Retention(RetentionPolicy.RUNTIME) +public @interface Filename { + String value(); +} diff --git a/src/org/python/core/AnnotationReader.java b/src/org/python/core/AnnotationReader.java --- a/src/org/python/core/AnnotationReader.java +++ b/src/org/python/core/AnnotationReader.java @@ -24,9 +24,11 @@ private boolean nextVisitIsVersion = false; private boolean nextVisitIsMTime = false; + private boolean nextVisitIsFilename = false; private int version = -1; private long mtime = -1; + private String filename = null; /** * Reads the classfile bytecode in data and to extract the version.
@@ -50,6 +52,7 @@ public AnnotationVisitor visitAnnotation(String desc, boolean visible) { nextVisitIsVersion = desc.equals("Lorg/python/compiler/APIVersion;"); nextVisitIsMTime = desc.equals("Lorg/python/compiler/MTime;"); + nextVisitIsFilename = desc.equals("Lorg/python/compiler/Filename;"); return new AnnotationVisitor(Opcodes.ASM4) { public void visit(String name, Object value) { @@ -58,8 +61,11 @@ nextVisitIsVersion = false; } else if (nextVisitIsMTime) { mtime = (Long)value; - nextVisitIsVersion = false; - } + nextVisitIsMTime = false; + } else if (nextVisitIsFilename) { + filename = (String)value; + nextVisitIsFilename = false; + } } }; } @@ -71,4 +77,8 @@ public long getMTime() { return mtime; } + + public String getFilename() { + return filename; + } } diff --git a/src/org/python/core/imp.java b/src/org/python/core/imp.java --- a/src/org/python/core/imp.java +++ b/src/org/python/core/imp.java @@ -27,7 +27,7 @@ private static final String UNKNOWN_SOURCEFILE = ""; - private static final int APIVersion = 34; + private static final int APIVersion = 35; public static final int NO_MTIME = -1; @@ -35,6 +35,34 @@ // imports unless `from __future__ import absolute_import` public static final int DEFAULT_LEVEL = -1; + public static class CodeData { + private final byte[] bytes; + private final long mtime; + private final String filename; + + public CodeData(byte[] bytes, long mtime, String filename) { + this.bytes = bytes; + this.mtime = mtime; + this.filename = filename; + } + + public byte[] getBytes() { + return bytes; + } + + public long getMTime() { + return mtime; + } + + public String getFilename() { + return filename; + } + } + + public static enum CodeImport { + source, compiled_only; + } + /** A non-empty fromlist for __import__'ing sub-modules. 
*/ private static final PyObject nonEmptyFromlist = new PyTuple(Py.newString("__doc__")); @@ -174,9 +202,14 @@ static PyObject createFromPyClass(String name, InputStream fp, boolean testing, String sourceName, String compiledName, long mtime) { - byte[] data = null; + return createFromPyClass(name, fp, testing, sourceName, compiledName, mtime, CodeImport.source); + } + + static PyObject createFromPyClass(String name, InputStream fp, boolean testing, + String sourceName, String compiledName, long mtime, CodeImport source) { + CodeData data = null; try { - data = readCode(name, fp, testing, mtime); + data = readCodeData(name, fp, testing, mtime); } catch (IOException ioe) { if (!testing) { throw Py.ImportError(ioe.getMessage() + "[name=" + name + ", source=" + sourceName @@ -188,7 +221,8 @@ } PyCode code; try { - code = BytecodeLoader.makeCode(name + "$py", data, sourceName); + code = BytecodeLoader.makeCode(name + "$py", data.getBytes(), + source == CodeImport.compiled_only ? data.getFilename() : sourceName); } catch (Throwable t) { if (testing) { return null; @@ -199,7 +233,6 @@ Py.writeComment(IMPORT_LOG, String.format("import %s # precompiled from %s", name, compiledName)); - return createFromCode(name, code, compiledName); } @@ -208,6 +241,14 @@ } public static byte[] readCode(String name, InputStream fp, boolean testing, long mtime) throws IOException { + return readCodeData(name, fp, testing, mtime).getBytes(); + } + + public static CodeData readCodeData(String name, InputStream fp, boolean testing) throws IOException { + return readCodeData(name, fp, testing, NO_MTIME); + } + + public static CodeData readCodeData(String name, InputStream fp, boolean testing, long mtime) throws IOException { byte[] data = readBytes(fp); int api; AnnotationReader ar = new AnnotationReader(data); @@ -226,7 +267,7 @@ return null; } } - return data; + return new CodeData(data, mtime, ar.getFilename()); } public static byte[] compileSource(String name, File file) { @@ -582,8 +623,9 
@@ Py.writeDebug(IMPORT_LOG, "trying precompiled " + compiledFile.getPath()); long classTime = compiledFile.lastModified(); if (classTime >= pyTime) { - PyObject ret = createFromPyClass(modName, makeStream(compiledFile), true, - displaySourceName, displayCompiledName, pyTime); + PyObject ret = createFromPyClass( + modName, makeStream(compiledFile), true, + displaySourceName, displayCompiledName, pyTime); if (ret != null) { return ret; } @@ -598,8 +640,9 @@ // If no source, try loading precompiled Py.writeDebug(IMPORT_LOG, "trying precompiled with no source " + compiledFile.getPath()); if (compiledFile.isFile() && caseok(compiledFile, compiledName)) { - return createFromPyClass(modName, makeStream(compiledFile), true, displaySourceName, - displayCompiledName); + return createFromPyClass( + modName, makeStream(compiledFile), true, displaySourceName, + displayCompiledName, NO_MTIME, CodeImport.compiled_only); } } catch (SecurityException e) { // ok -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Fri Dec 26 15:19:29 2014 From: jython-checkins at python.org (jim.baker) Date: Fri, 26 Dec 2014 14:19:29 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Do_not_emit_RuntimeWarning_?= =?utf-8?q?for_absolute_imports_with_empty_parent?= Message-ID: <20141226141928.71917.68142@psf.io> https://hg.python.org/jython/rev/01dd1d307b41 changeset: 7471:01dd1d307b41 user: Jim Baker date: Fri Dec 26 07:19:23 2014 -0700 summary: Do not emit RuntimeWarning for absolute imports with empty parent files: src/org/python/core/imp.java | 10 ++++++---- 1 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/org/python/core/imp.java b/src/org/python/core/imp.java --- a/src/org/python/core/imp.java +++ b/src/org/python/core/imp.java @@ -774,10 +774,12 @@ if (Py.getSystemState().modules.__finditem__(modname) == null) { if (orig_level < 1) { - Py.warning(Py.RuntimeWarning, - String.format( - "Parent module '%.200s' not found " + - "while handling 
absolute import", modname)); + if (modname.length() > 0) { + Py.warning(Py.RuntimeWarning, + String.format( + "Parent module '%.200s' not found " + + "while handling absolute import", modname)); + } } else { throw Py.SystemError(String.format( "Parent module '%.200s' not loaded, " + -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 27 16:59:10 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 27 Dec 2014 15:59:10 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Added_missing_Python_method?= =?utf-8?q?s_for_proxied_java=2Eutil=2EList_objects?= Message-ID: <20141227155907.120051.25848@psf.io> https://hg.python.org/jython/rev/03d04033c305 changeset: 7472:03d04033c305 user: Santoso Wijaya date: Fri Dec 26 07:46:03 2014 -0700 summary: Added missing Python methods for proxied java.util.List objects Squashed https://bitbucket.org/santa4nt/jython/commits/branch/fix-issue-2215 as a single changeset Part of the fix for http://bugs.jython.org/issue2215 files: Lib/test/test_list.py | 26 +- Lib/test/test_list_jy.py | 16 +- src/org/python/core/PyJavaType.java | 325 +++++++++++++++- 3 files changed, 341 insertions(+), 26 deletions(-) diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py --- a/Lib/test/test_list.py +++ b/Lib/test/test_list.py @@ -2,18 +2,19 @@ from test import test_support, list_tests class ListTest(list_tests.CommonTest): + type2test = list def test_basic(self): - self.assertEqual(list([]), []) + self.assertEqual(self.type2test([]), []) l0_3 = [0, 1, 2, 3] - l0_3_bis = list(l0_3) + l0_3_bis = self.type2test(l0_3) self.assertEqual(l0_3, l0_3_bis) self.assertTrue(l0_3 is not l0_3_bis) - self.assertEqual(list(()), []) - self.assertEqual(list((0, 1, 2, 3)), [0, 1, 2, 3]) - self.assertEqual(list(''), []) - self.assertEqual(list('spam'), ['s', 'p', 'a', 'm']) + self.assertEqual(self.type2test(()), []) + self.assertEqual(self.type2test((0, 1, 2, 3)), [0, 1, 2, 3]) + self.assertEqual(self.type2test(''), []) 
+ self.assertEqual(self.type2test('spam'), ['s', 'p', 'a', 'm']) #FIXME: too brutal for us ATM. if not test_support.is_jython: @@ -41,20 +42,21 @@ def test_truth(self): super(ListTest, self).test_truth() - self.assertTrue(not []) - self.assertTrue([42]) + self.assertTrue(not self.type2test([])) + self.assertTrue(self.type2test([42])) def test_identity(self): self.assertTrue([] is not []) + self.assertTrue(self.type2test([]) is not self.type2test([])) def test_len(self): super(ListTest, self).test_len() - self.assertEqual(len([]), 0) - self.assertEqual(len([0]), 1) - self.assertEqual(len([0, 1, 2]), 3) + self.assertEqual(len(self.type2test([])), 0) + self.assertEqual(len(self.type2test([0])), 1) + self.assertEqual(len(self.type2test([0, 1, 2])), 3) def test_overflow(self): - lst = [4, 5, 6, 7] + lst = self.type2test([4, 5, 6, 7]) n = int((sys.maxint*2+2) // len(lst)) def mul(a, b): return a * b def imul(a, b): a *= b diff --git a/Lib/test/test_list_jy.py b/Lib/test/test_list_jy.py --- a/Lib/test/test_list_jy.py +++ b/Lib/test/test_list_jy.py @@ -3,6 +3,7 @@ import threading import time from test import test_support +import test_list if test_support.is_jython: from java.util import ArrayList @@ -209,10 +210,23 @@ self.assertEqual(a, expected4) +class JavaListTestCase(test_list.ListTest): + + type2test = ArrayList + + def test_extend_java_ArrayList(self): + jl = ArrayList([]) + jl.extend([1,2]) + self.assertEqual(jl, ArrayList([1,2])) + jl.extend(ArrayList([3,4])) + self.assertEqual(jl, [1,2,3,4]) + + def test_main(): test_support.run_unittest(ListTestCase, ThreadSafetyTestCase, - ExtendedSliceTestCase) + ExtendedSliceTestCase, + JavaListTestCase) if __name__ == "__main__": test_main() diff --git a/src/org/python/core/PyJavaType.java b/src/org/python/core/PyJavaType.java --- a/src/org/python/core/PyJavaType.java +++ b/src/org/python/core/PyJavaType.java @@ -15,17 +15,9 @@ import java.lang.reflect.Member; import java.lang.reflect.Method; import 
java.lang.reflect.Modifier; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Enumeration; -import java.util.EventListener; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.Queue; +import java.util.*; +import com.google.common.collect.Lists; import org.python.core.util.StringUtil; import org.python.util.Generic; @@ -898,6 +890,10 @@ super(name, numArgs); } + protected ListMethod(String name, int minArgs, int maxArgs) { + super(name, minArgs, maxArgs); + } + protected List asList(){ return (List)self.getJavaProxy(); } @@ -1442,9 +1438,299 @@ return Py.None; } }; - collectionProxies.put(List.class, new PyBuiltinMethod[] {listGetProxy, - listSetProxy, - listRemoveProxy}); + PyBuiltinMethodNarrow listEqProxy = new ListMethod("__eq__", 1) { + @Override + public PyObject __call__(PyObject other) { + List jList = asList(); + if (other.getType().isSubType(PyList.TYPE)) { + PyList oList = (PyList) other; + if (jList.size() != oList.size()) { + return Py.False; + } + for (int i = 0; i < jList.size(); i++) { + if (!Py.java2py(jList.get(i))._eq(oList.pyget(i)).__nonzero__()) { + return Py.False; + } + } + return Py.True; + } else { + Object oj = other.getJavaProxy(); + if (oj instanceof List) { + List oList = (List) oj; + if (jList.size() != oList.size()) { + return Py.False; + } + for (int i = 0; i < jList.size(); i++) { + if (!Py.java2py(jList.get(i))._eq( + Py.java2py(oList.get(i))).__nonzero__()) { + return Py.False; + } + } + return Py.True; + } else { + return null; + } + } + } + }; + PyBuiltinMethodNarrow listAppendProxy = new ListMethod("append", 1) { + @Override + public PyObject __call__(PyObject value) { + asList().add(value); + return Py.None; + } + }; + PyBuiltinMethodNarrow listExtendProxy = new ListMethod("extend", 1) { + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + if (obj instanceof Collection) { + 
jList.addAll((Collection) obj); + } else { + for (PyObject item : obj.asIterable()) { + jList.add(item); + } + } + return Py.None; + } + }; + PyBuiltinMethodNarrow listInsertProxy = new ListMethod("insert", 2) { + @Override + public PyObject __call__(PyObject index, PyObject object) { + List jlist = asList(); + ListIndexDelegate lid = new ListIndexDelegate(jlist); + int idx = lid.fixBoundIndex(index.asIndex()); + jlist.add(idx, object); + return Py.None; + } + }; + PyBuiltinMethodNarrow listPopProxy = new ListMethod("pop", 0, 1) { + @Override + public PyObject __call__() { + return __call__(Py.newInteger(-1)); + } + @Override + public PyObject __call__(PyObject index) { + List jlist = asList(); + if (jlist.isEmpty()) { + throw Py.IndexError("pop from empty list"); + } + ListIndexDelegate ldel = new ListIndexDelegate(jlist); + PyObject item = ldel.checkIdxAndFindItem(index.asInt()); + if (item == null) { + throw Py.IndexError("pop index out of range"); + } else { + ldel.checkIdxAndDelItem(index); + return item; + } + } + }; + PyBuiltinMethodNarrow listIndexProxy = new ListMethod("index", 1, 3) { + @Override + public PyObject __call__(PyObject object) { + return __call__(object, Py.newInteger(0), Py.newInteger(asList().size())); + } + @Override + public PyObject __call__(PyObject object, PyObject start) { + return __call__(object, start, Py.newInteger(asList().size())); + } + @Override + public PyObject __call__(PyObject object, PyObject start, PyObject end) { + List jlist = asList(); + ListIndexDelegate lid = new ListIndexDelegate(jlist); + int st = lid.fixBoundIndex(start.asInt()); + int en = lid.fixBoundIndex(end.asInt()); + for (int i = st; i < en; i++) { + Object jobj = jlist.get(i); + if (Py.java2py(jobj)._eq(object).__nonzero__()) { + return Py.newInteger(i); + } + } + throw Py.ValueError(object.toString() + " is not in list"); + } + }; + PyBuiltinMethodNarrow listCountProxy = new ListMethod("count", 1) { + @Override + public PyObject __call__(PyObject object) 
{ + int count = 0; + List jlist = asList(); + for (int i = 0; i < jlist.size(); i++) { + Object jobj = jlist.get(i); + if (Py.java2py(jobj)._eq(object).__nonzero__()) { + ++count; + } + } + return Py.newInteger(count); + } + }; + PyBuiltinMethodNarrow listReverseProxy = new ListMethod("reverse", 0) { + @Override + public PyObject __call__() { + List jlist = asList(); + Collections.reverse(jlist); + return Py.None; + } + }; + PyBuiltinMethodNarrow listRemoveOverrideProxy = new ListMethod("remove", 1) { + @Override + public PyObject __call__(PyObject object) { + List jlist = asList(); + for (int i = 0; i < jlist.size(); i++) { + Object jobj = jlist.get(i); + if (Py.java2py(jobj)._eq(object).__nonzero__()) { + jlist.remove(i); + return Py.None; + } + } + throw Py.ValueError(object.toString() + " is not in list"); + } + }; + PyBuiltinMethodNarrow listRAddProxy = new ListMethod("__radd__", 1) { + @Override + public PyObject __call__(PyObject obj) { + // first, clone the self list + List jList = asList(); + List jClone; + try { + jClone = (List) jList.getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } + for (Object entry : jList) { + jClone.add(entry); + } + + // then, extend it with elements from the other list + // (but, since this is reverse add, we are technically + // pre-pending the clone with elements from the other list) + if (obj instanceof Collection) { + jClone.addAll(0, (Collection) obj); + } else { + int i = 0; + for (PyObject item : obj.asIterable()) { + jClone.add(i, item); + i++; + } + } + + return Py.java2py(jClone); + } + }; + PyBuiltinMethodNarrow listIAddProxy = new ListMethod("__iadd__", 1) { + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + if (obj instanceof Collection) { + jList.addAll((Collection) obj); + } else { + for (PyObject item : obj.asIterable()) { + jList.add(item); + } + } + return self; + } + }; + 
PyBuiltinMethodNarrow listRMulProxy = new ListMethod("__rmul__", 1) { + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + int mult = obj.asInt(); + + List jClone; + try { + jClone = (List) jList.getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } + + // anything below 0 multiplier, we return an empty list + if (mult > 0) { + // otherwise, extend it x times, where x is int-cast from obj + for (; mult > 0; mult--) { + for (Object entry : jList) { + jClone.add(entry); + } + } + } + + return Py.java2py(jClone); + } + }; + PyBuiltinMethodNarrow listMulProxy = new ListMethod("__mul__", 1) { + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + int mult = obj.asInt(); + + List jClone; + try { + jClone = (List) jList.getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } + + // anything below 0 multiplier, we return an empty list + if (mult > 0) { + // otherwise, extend it x times, where x is int-cast from obj + for (; mult > 0; mult--) { + for (Object entry : jList) { + jClone.add(entry); + } + } + } + + return Py.java2py(jClone); + } + }; + PyBuiltinMethodNarrow listIMulProxy = new ListMethod("__imul__", 1) { + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + int mult = obj.asInt(); + + // anything below 0 multiplier, we clear the list + if (mult <= 0) { + jList.clear(); + } else { + // otherwise, extend it (in-place) x times, where x is int-cast from obj + int originalSize = jList.size(); + for (mult = mult - 1; mult > 0; mult--) { + for (int i = 0; i < originalSize; i++) { + jList.add(jList.get(i)); + } + } + } + + return self; + } + }; + collectionProxies.put(List.class, new PyBuiltinMethod[] { + listGetProxy, + listSetProxy, + listEqProxy, + listRemoveProxy, + 
listAppendProxy, + listExtendProxy, + listInsertProxy, + listPopProxy, + listIndexProxy, + listCountProxy, + listReverseProxy, + listRAddProxy, + listIAddProxy, + listRMulProxy, + listMulProxy, + listIMulProxy, + }); + postCollectionProxies.put(List.class, new PyBuiltinMethod[]{ + listRemoveOverrideProxy, + }); } return collectionProxies; } @@ -1503,6 +1789,19 @@ return list.size(); } + protected int fixBoundIndex(int index) { + int l = len(); + if (index < 0) { + index += l; + if (index < 0) { + index = 0; + } + } else if (index > l) { + index = l; + } + return index; + } + @Override public void setItem(int idx, PyObject value) { list.set(idx, value.__tojava__(Object.class)); -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sat Dec 27 16:59:10 2014 From: jython-checkins at python.org (jim.baker) Date: Sat, 27 Dec 2014 15:59:10 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Completes_proxying_of_java?= =?utf-8?q?=2Eutil=2EList_objects?= Message-ID: <20141227155907.22210.43207@psf.io> https://hg.python.org/jython/rev/ae2a1efe4192 changeset: 7473:ae2a1efe4192 user: Jim Baker date: Sat Dec 27 08:58:27 2014 -0700 summary: Completes proxying of java.util.List objects Objects implementing the java.util.List interface are now fully proxied so they appear to be semantically equivalent to Python objects. This completes the fix for http://bugs.jython.org/issue2215, although similar treatment still needs to be done for java.util.Set. Important notes, including a backwards breaking change: * The remove method on such proxied List objects now follows Python semantics: a ValueError is raised if the item is not found in the java.util.List. In the past, Java semantics were used, with a boolean returned indicating if the item was removed or not. Given how this interacted with Jython, such remove invocations were in the past silent, and perhaps were actually a bug. 
In the Jython standard library, one module had to be updated (Lib/_socket.py, which provides the core socket and select semantics using Netty 4), so this backwards breaking change may have minimal impact. See https://docs.python.org/2/library/stdtypes.html#typesseq-mutable * List-like objects (those implementing __len__, __getitem__) cannot be used to initialize proxied java.util.List objects. This is a separate problem - they should be proxied so that they appear like List objects in Java space, and available as such for any Java code. * However, xrange objects can now be treated as a java.lang.Iterable, which is seen in similar construction tests. * No additional synchronization guarantees are provided on the underlying java.util.List. Use Collections.synchronizedList if this is needed. This contrasts with how Jython's builtin list object works, but is comparable to how we chose to proxy java.util.Map objects. files: Lib/_socket.py | 7 +- Lib/test/list_tests.py | 41 +- Lib/test/seq_tests.py | 409 ++++++++++++++ Lib/test/test_list_jy.py | 14 + src/org/python/core/JavaIterator.java | 20 + src/org/python/core/PyIterator.java | 21 + src/org/python/core/PyJavaType.java | 312 +++++++--- src/org/python/core/PyList.java | 20 + src/org/python/core/PyString.java | 11 + src/org/python/core/PyXRange.java | 23 + 10 files changed, 776 insertions(+), 102 deletions(-) diff --git a/Lib/_socket.py b/Lib/_socket.py --- a/Lib/_socket.py +++ b/Lib/_socket.py @@ -580,7 +580,7 @@ def initChannel(self, child_channel): child = ChildSocket(self.parent_socket) - log.debug("Initializing child %s", extra={"sock": self.parent_socket}) + log.debug("Initializing child %s", child, extra={"sock": self.parent_socket}) child.proto = IPPROTO_TCP child._init_client_mode(child_channel) @@ -732,7 +732,10 @@ self.selectors.addIfAbsent(selector) def _unregister_selector(self, selector): - return self.selectors.remove(selector) + try: + return self.selectors.remove(selector) + except ValueError: + return 
None def _notify_selectors(self, exception=None, hangup=False): for selector in self.selectors: diff --git a/Lib/test/list_tests.py b/Lib/test/list_tests.py --- a/Lib/test/list_tests.py +++ b/Lib/test/list_tests.py @@ -4,9 +4,14 @@ import sys import os +import unittest from test import test_support, seq_tests +if test_support.is_jython: + from java.util import List as JList + + class CommonTest(seq_tests.CommonTest): def test_init(self): @@ -40,12 +45,14 @@ self.assertEqual(str(a2), "[0, 1, 2]") self.assertEqual(repr(a2), "[0, 1, 2]") - a2.append(a2) - a2.append(3) - self.assertEqual(str(a2), "[0, 1, 2, [...], 3]") - self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]") + if not (test_support.is_jython and issubclass(self.type2test, JList)): + # Jython does not support shallow copies of object graphs + # when moving back and forth from Java object space + a2.append(a2) + a2.append(3) + self.assertEqual(str(a2), "[0, 1, 2, [...], 3]") + self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]") - #FIXME: not working on Jython if not test_support.is_jython: l0 = [] for i in xrange(sys.getrecursionlimit() + 100): @@ -53,6 +60,8 @@ self.assertRaises(RuntimeError, repr, l0) def test_print(self): + if test_support.is_jython and issubclass(self.type2test, JList): + raise unittest.SkipTest("Jython does not support shallow copies of object graphs") d = self.type2test(xrange(200)) d.append(d) d.extend(xrange(200,400)) @@ -184,10 +193,14 @@ a[:] = tuple(range(10)) self.assertEqual(a, self.type2test(range(10))) - self.assertRaises(TypeError, a.__setslice__, 0, 1, 5) + if not (test_support.is_jython and issubclass(self.type2test, JList)): + # no support for __setslice__ on Jython for + # java.util.List, given that method deprecated since 2.0! 
+ self.assertRaises(TypeError, a.__setslice__, 0, 1, 5) self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5)) - self.assertRaises(TypeError, a.__setslice__) + if not (test_support.is_jython and issubclass(self.type2test, JList)): + self.assertRaises(TypeError, a.__setslice__) self.assertRaises(TypeError, a.__setitem__) def test_delslice(self): @@ -330,9 +343,12 @@ d = self.type2test(['a', 'b', BadCmp2(), 'c']) e = self.type2test(d) self.assertRaises(BadExc, d.remove, 'c') - for x, y in zip(d, e): - # verify that original order and values are retained. - self.assertIs(x, y) + if not (test_support.is_jython and issubclass(self.type2test, JList)): + # When converting back and forth to Java space, Jython does not + # maintain object identity + for x, y in zip(d, e): + # verify that original order and values are retained. + self.assertIs(x, y) def test_count(self): a = self.type2test([0, 1, 2])*3 @@ -452,8 +468,13 @@ def selfmodifyingComparison(x,y): z.append(1) return cmp(x, y) + + # Need to ensure the comparisons are actually executed by + # setting up a list + z = self.type2test(range(12)) self.assertRaises(ValueError, z.sort, selfmodifyingComparison) + z = self.type2test(range(12)) self.assertRaises(TypeError, z.sort, lambda x, y: 's') self.assertRaises(TypeError, z.sort, 42, 42, 42, 42) diff --git a/Lib/test/seq_tests.py b/Lib/test/seq_tests.py new file mode 100644 --- /dev/null +++ b/Lib/test/seq_tests.py @@ -0,0 +1,409 @@ +""" +Tests common to tuple, list and UserList.UserList +""" + +import unittest +import sys + +from test import test_support + +if test_support.is_jython: + from java.util import List as JList + +# Various iterables +# This is used for checking the constructor (here and in test_deque.py) +def iterfunc(seqn): + 'Regular generator' + for i in seqn: + yield i + +class Sequence: + 'Sequence using __getitem__' + def __init__(self, seqn): + self.seqn = seqn + def __getitem__(self, i): + return self.seqn[i] + +class IterFunc: + 'Sequence using 
iterator protocol' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def next(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class IterGen: + 'Sequence using iterator protocol defined with a generator' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + for val in self.seqn: + yield val + +class IterNextOnly: + 'Missing __getitem__ and __iter__' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def next(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class IterNoNext: + 'Iterator missing next()' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + +class IterGenExc: + 'Test propagation of exceptions' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def next(self): + 3 // 0 + +class IterFuncStop: + 'Test immediate stop' + def __init__(self, seqn): + pass + def __iter__(self): + return self + def next(self): + raise StopIteration + +from itertools import chain, imap +def itermulti(seqn): + 'Test multiple tiers of iterators' + return chain(imap(lambda x:x, iterfunc(IterGen(Sequence(seqn))))) + +class CommonTest(unittest.TestCase): + # The type to be tested + type2test = None + + def test_constructors(self): + l0 = [] + l1 = [0] + l2 = [0, 1] + + u = self.type2test() + u0 = self.type2test(l0) + u1 = self.type2test(l1) + u2 = self.type2test(l2) + + uu = self.type2test(u) + uu0 = self.type2test(u0) + uu1 = self.type2test(u1) + uu2 = self.type2test(u2) + + v = self.type2test(tuple(u)) + class OtherSeq: + def __init__(self, initseq): + self.__data = initseq + def __len__(self): + return len(self.__data) + def __getitem__(self, i): + return self.__data[i] + if not (test_support.is_jython and issubclass(self.type2test, JList)): + # Jython does not currently support in 
reflected args + # converting List-like objects to Lists. This lack of + # support should be fixed, but it's tricky. + s = OtherSeq(u0) + v0 = self.type2test(s) + self.assertEqual(len(v0), len(s)) + + s = "this is also a sequence" + vv = self.type2test(s) + self.assertEqual(len(vv), len(s)) + + + if test_support.is_jython and issubclass(self.type2test, JList): + # Ditto from above, we need to skip the rest of the test + return + + # Create from various iteratables + for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): + for g in (Sequence, IterFunc, IterGen, + itermulti, iterfunc): + self.assertEqual(self.type2test(g(s)), self.type2test(s)) + self.assertEqual(self.type2test(IterFuncStop(s)), self.type2test()) + self.assertEqual(self.type2test(c for c in "123"), self.type2test("123")) + self.assertRaises(TypeError, self.type2test, IterNextOnly(s)) + self.assertRaises(TypeError, self.type2test, IterNoNext(s)) + self.assertRaises(ZeroDivisionError, self.type2test, IterGenExc(s)) + + def test_truth(self): + self.assertFalse(self.type2test()) + self.assertTrue(self.type2test([42])) + + def test_getitem(self): + u = self.type2test([0, 1, 2, 3, 4]) + for i in xrange(len(u)): + self.assertEqual(u[i], i) + self.assertEqual(u[long(i)], i) + for i in xrange(-len(u), -1): + self.assertEqual(u[i], len(u)+i) + self.assertEqual(u[long(i)], len(u)+i) + self.assertRaises(IndexError, u.__getitem__, -len(u)-1) + self.assertRaises(IndexError, u.__getitem__, len(u)) + self.assertRaises(ValueError, u.__getitem__, slice(0,10,0)) + + u = self.type2test() + self.assertRaises(IndexError, u.__getitem__, 0) + self.assertRaises(IndexError, u.__getitem__, -1) + + self.assertRaises(TypeError, u.__getitem__) + + a = self.type2test([10, 11]) + self.assertEqual(a[0], 10) + self.assertEqual(a[1], 11) + self.assertEqual(a[-2], 10) + self.assertEqual(a[-1], 11) + self.assertRaises(IndexError, a.__getitem__, -3) + self.assertRaises(IndexError, a.__getitem__, 3) + + def 
test_getslice(self): + l = [0, 1, 2, 3, 4] + u = self.type2test(l) + + self.assertEqual(u[0:0], self.type2test()) + self.assertEqual(u[1:2], self.type2test([1])) + self.assertEqual(u[-2:-1], self.type2test([3])) + self.assertEqual(u[-1000:1000], u) + self.assertEqual(u[1000:-1000], self.type2test([])) + self.assertEqual(u[:], u) + self.assertEqual(u[1:None], self.type2test([1, 2, 3, 4])) + self.assertEqual(u[None:3], self.type2test([0, 1, 2])) + + # Extended slices + self.assertEqual(u[::], u) + self.assertEqual(u[::2], self.type2test([0, 2, 4])) + self.assertEqual(u[1::2], self.type2test([1, 3])) + self.assertEqual(u[::-1], self.type2test([4, 3, 2, 1, 0])) + self.assertEqual(u[::-2], self.type2test([4, 2, 0])) + self.assertEqual(u[3::-2], self.type2test([3, 1])) + self.assertEqual(u[3:3:-2], self.type2test([])) + self.assertEqual(u[3:2:-2], self.type2test([3])) + self.assertEqual(u[3:1:-2], self.type2test([3])) + self.assertEqual(u[3:0:-2], self.type2test([3, 1])) + self.assertEqual(u[::-100], self.type2test([4])) + self.assertEqual(u[100:-100:], self.type2test([])) + self.assertEqual(u[-100:100:], u) + self.assertEqual(u[100:-100:-1], u[::-1]) + self.assertEqual(u[-100:100:-1], self.type2test([])) + self.assertEqual(u[-100L:100L:2L], self.type2test([0, 2, 4])) + + # Test extreme cases with long ints + a = self.type2test([0,1,2,3,4]) + self.assertEqual(a[ -pow(2,128L): 3 ], self.type2test([0,1,2])) + self.assertEqual(a[ 3: pow(2,145L) ], self.type2test([3,4])) + + if not (test_support.is_jython and issubclass(self.type2test, JList)): + # no support for __getslice__ on Jython for + # java.util.List, given that method deprecated since 2.0! 
+ self.assertRaises(TypeError, u.__getslice__) + + def test_contains(self): + u = self.type2test([0, 1, 2]) + for i in u: + self.assertIn(i, u) + for i in min(u)-1, max(u)+1: + self.assertNotIn(i, u) + + self.assertRaises(TypeError, u.__contains__) + + def test_contains_fake(self): + class AllEq: + # Sequences must use rich comparison against each item + # (unless "is" is true, or an earlier item answered) + # So instances of AllEq must be found in all non-empty sequences. + def __eq__(self, other): + return True + __hash__ = None # Can't meet hash invariant requirements + self.assertNotIn(AllEq(), self.type2test([])) + self.assertIn(AllEq(), self.type2test([1])) + + def test_contains_order(self): + # Sequences must test in-order. If a rich comparison has side + # effects, these will be visible to tests against later members. + # In this test, the "side effect" is a short-circuiting raise. + class DoNotTestEq(Exception): + pass + class StopCompares: + def __eq__(self, other): + raise DoNotTestEq + + checkfirst = self.type2test([1, StopCompares()]) + self.assertIn(1, checkfirst) + checklast = self.type2test([StopCompares(), 1]) + self.assertRaises(DoNotTestEq, checklast.__contains__, 1) + + def test_len(self): + self.assertEqual(len(self.type2test()), 0) + self.assertEqual(len(self.type2test([])), 0) + self.assertEqual(len(self.type2test([0])), 1) + self.assertEqual(len(self.type2test([0, 1, 2])), 3) + + def test_minmax(self): + u = self.type2test([0, 1, 2]) + self.assertEqual(min(u), 0) + self.assertEqual(max(u), 2) + + def test_addmul(self): + u1 = self.type2test([0]) + u2 = self.type2test([0, 1]) + self.assertEqual(u1, u1 + self.type2test()) + self.assertEqual(u1, self.type2test() + u1) + self.assertEqual(u1 + self.type2test([1]), u2) + self.assertEqual(self.type2test([-1]) + u1, self.type2test([-1, 0])) + self.assertEqual(self.type2test(), u2*0) + self.assertEqual(self.type2test(), 0*u2) + self.assertEqual(self.type2test(), u2*0L) + 
self.assertEqual(self.type2test(), 0L*u2) + self.assertEqual(u2, u2*1) + self.assertEqual(u2, 1*u2) + self.assertEqual(u2, u2*1L) + self.assertEqual(u2, 1L*u2) + self.assertEqual(u2+u2, u2*2) + self.assertEqual(u2+u2, 2*u2) + self.assertEqual(u2+u2, u2*2L) + self.assertEqual(u2+u2, 2L*u2) + self.assertEqual(u2+u2+u2, u2*3) + self.assertEqual(u2+u2+u2, 3*u2) + + class subclass(self.type2test): + pass + u3 = subclass([0, 1]) + self.assertEqual(u3, u3*1) + self.assertIsNot(u3, u3*1) + + def test_iadd(self): + u = self.type2test([0, 1]) + u += self.type2test() + self.assertEqual(u, self.type2test([0, 1])) + u += self.type2test([2, 3]) + self.assertEqual(u, self.type2test([0, 1, 2, 3])) + u += self.type2test([4, 5]) + self.assertEqual(u, self.type2test([0, 1, 2, 3, 4, 5])) + + u = self.type2test("spam") + u += self.type2test("eggs") + self.assertEqual(u, self.type2test("spameggs")) + + def test_imul(self): + u = self.type2test([0, 1]) + u *= 3 + self.assertEqual(u, self.type2test([0, 1, 0, 1, 0, 1])) + + def test_getitemoverwriteiter(self): + # Verify that __getitem__ overrides are not recognized by __iter__ + class T(self.type2test): + def __getitem__(self, key): + return str(key) + '!!!' 
+ self.assertEqual(iter(T((1,2))).next(), 1) + + def test_repeat(self): + for m in xrange(4): + s = tuple(range(m)) + for n in xrange(-3, 5): + self.assertEqual(self.type2test(s*n), self.type2test(s)*n) + self.assertEqual(self.type2test(s)*(-4), self.type2test([])) + self.assertEqual(id(s), id(s*1)) + + def test_bigrepeat(self): + import sys + if sys.maxint <= 2147483647: + x = self.type2test([0]) + x *= 2**16 + self.assertRaises(MemoryError, x.__mul__, 2**16) + if hasattr(x, '__imul__'): + self.assertRaises(MemoryError, x.__imul__, 2**16) + + def test_subscript(self): + a = self.type2test([10, 11]) + self.assertEqual(a.__getitem__(0L), 10) + self.assertEqual(a.__getitem__(1L), 11) + self.assertEqual(a.__getitem__(-2L), 10) + self.assertEqual(a.__getitem__(-1L), 11) + self.assertRaises(IndexError, a.__getitem__, -3) + self.assertRaises(IndexError, a.__getitem__, 3) + self.assertEqual(a.__getitem__(slice(0,1)), self.type2test([10])) + self.assertEqual(a.__getitem__(slice(1,2)), self.type2test([11])) + self.assertEqual(a.__getitem__(slice(0,2)), self.type2test([10, 11])) + self.assertEqual(a.__getitem__(slice(0,3)), self.type2test([10, 11])) + self.assertEqual(a.__getitem__(slice(3,5)), self.type2test([])) + self.assertRaises(ValueError, a.__getitem__, slice(0, 10, 0)) + self.assertRaises(TypeError, a.__getitem__, 'x') + + def test_count(self): + a = self.type2test([0, 1, 2])*3 + self.assertEqual(a.count(0), 3) + self.assertEqual(a.count(1), 3) + self.assertEqual(a.count(3), 0) + + self.assertRaises(TypeError, a.count) + + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False + + self.assertRaises(BadExc, a.count, BadCmp()) + + def test_index(self): + u = self.type2test([0, 1]) + self.assertEqual(u.index(0), 0) + self.assertEqual(u.index(1), 1) + self.assertRaises(ValueError, u.index, 2) + + u = self.type2test([-2, -1, 0, 0, 1, 2]) + self.assertEqual(u.count(0), 2) + self.assertEqual(u.index(0), 
2) + self.assertEqual(u.index(0, 2), 2) + self.assertEqual(u.index(-2, -10), 0) + self.assertEqual(u.index(0, 3), 3) + self.assertEqual(u.index(0, 3, 4), 3) + self.assertRaises(ValueError, u.index, 2, 0, -10) + + self.assertRaises(TypeError, u.index) + + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False + + a = self.type2test([0, 1, 2, 3]) + self.assertRaises(BadExc, a.index, BadCmp()) + + a = self.type2test([-2, -1, 0, 0, 1, 2]) + self.assertEqual(a.index(0), 2) + self.assertEqual(a.index(0, 2), 2) + self.assertEqual(a.index(0, -4), 2) + self.assertEqual(a.index(-2, -10), 0) + self.assertEqual(a.index(0, 3), 3) + self.assertEqual(a.index(0, -3), 3) + self.assertEqual(a.index(0, 3, 4), 3) + self.assertEqual(a.index(0, -3, -2), 3) + self.assertEqual(a.index(0, -4*sys.maxint, 4*sys.maxint), 2) + self.assertRaises(ValueError, a.index, 0, 4*sys.maxint,-4*sys.maxint) + self.assertRaises(ValueError, a.index, 2, 0, -10) diff --git a/Lib/test/test_list_jy.py b/Lib/test/test_list_jy.py --- a/Lib/test/test_list_jy.py +++ b/Lib/test/test_list_jy.py @@ -214,6 +214,20 @@ type2test = ArrayList + def test_init(self): + # Iterable arg is optional + self.assertEqual(self.type2test([]), self.type2test()) + + # Unlike with builtin types, we do not guarantee objects can + # be overwritten; see corresponding tests + + # Mutables always return a new object + a = self.type2test([1, 2, 3]) + b = self.type2test(a) + self.assertNotEqual(id(a), id(b)) + self.assertEqual(a, b) + + def test_extend_java_ArrayList(self): jl = ArrayList([]) jl.extend([1,2]) diff --git a/src/org/python/core/JavaIterator.java b/src/org/python/core/JavaIterator.java new file mode 100644 --- /dev/null +++ b/src/org/python/core/JavaIterator.java @@ -0,0 +1,20 @@ +package org.python.core; + +import java.util.Iterator; + +public class JavaIterator extends PyIterator { + + final private Iterator proxy; + + public JavaIterator(Iterable proxy) { + 
this(proxy.iterator()); + } + + public JavaIterator(Iterator proxy) { + this.proxy = proxy; + } + + public PyObject __iternext__() { + return proxy.hasNext() ? Py.java2py(proxy.next()) : null; + } +} diff --git a/src/org/python/core/PyIterator.java b/src/org/python/core/PyIterator.java --- a/src/org/python/core/PyIterator.java +++ b/src/org/python/core/PyIterator.java @@ -1,7 +1,10 @@ // Copyright 2000 Finn Bock package org.python.core; +import java.util.ArrayList; +import java.util.Collection; import java.util.Iterator; +import java.util.List; /** * An abstract helper class useful when implementing an iterator object. This implementation supply @@ -62,4 +65,22 @@ } }; } + + @Override + public Object __tojava__(Class c) { + if (c.isAssignableFrom(Iterable.class)) { + return this; + } + if (c.isAssignableFrom(Iterator.class)) { + return iterator(); + } + if (c.isAssignableFrom(Collection.class)) { + List list = new ArrayList(); + for (Object obj : this) { + list.add(obj); + } + return list; + } + return super.__tojava__(c); + } } diff --git a/src/org/python/core/PyJavaType.java b/src/org/python/core/PyJavaType.java --- a/src/org/python/core/PyJavaType.java +++ b/src/org/python/core/PyJavaType.java @@ -16,6 +16,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import com.google.common.collect.Lists; import org.python.core.util.StringUtil; @@ -897,6 +898,126 @@ protected List asList(){ return (List)self.getJavaProxy(); } + + protected List newList() { + try { + return (List) asList().getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } + } + } + + private static class ListMulProxyClass extends ListMethod { + protected ListMulProxyClass(String name, int numArgs) { + super(name, numArgs); + } + + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + 
int mult = obj.asInt(); + List newList = null; + // anything below 0 multiplier, we return an empty list + if (mult > 0) { + try { + newList = new ArrayList(jList.size() * mult); + // otherwise, extend it x times, where x is int-cast from obj + for (; mult > 0; mult--) { + for (Object entry : jList) { + newList.add(entry); + } + } + } catch (OutOfMemoryError t) { + throw Py.MemoryError(""); + } + } else { + newList = Collections.EMPTY_LIST; + } + return Py.java2py(newList); + } + } + + + private static class KV { + + private final PyObject key; + private final Object value; + + KV(PyObject key, Object value) { + this.key = key; + this.value = value; + } + } + + private static class KVComparator implements Comparator { + + private final PyObject cmp; + + KVComparator(PyObject cmp) { + this.cmp = cmp; + } + + public int compare(KV o1, KV o2) { + int result; + if (cmp != null && cmp != Py.None) { + PyObject pyresult = cmp.__call__(o1.key, o2.key); + if (pyresult instanceof PyInteger || pyresult instanceof PyLong) { + return pyresult.asInt(); + } else { + throw Py.TypeError( + String.format("comparison function must return int, not %.200s", + pyresult.getType().fastGetName())); + } + } else { + result = o1.key._cmp(o2.key); + } + return result; + } + + public boolean equals(Object o) { + if (o == this) { + return true; + } + + if (o instanceof KVComparator) { + return cmp.equals(((KVComparator) o).cmp); + } + return false; + } + } + + private synchronized static void list_sort(List list, PyObject cmp, PyObject key, boolean reverse) { + int size = list.size(); + final ArrayList decorated = new ArrayList(size); + for (Object value : list) { + PyObject pyvalue = Py.java2py(value); + if (key == null || key == Py.None) { + decorated.add(new KV(pyvalue, value)); + } + else { + decorated.add(new KV(key.__call__(pyvalue), value)); + } + } + // we will rebuild the list from the decorated version + list.clear(); + KVComparator c = new KVComparator(cmp); + if (reverse) { + 
Collections.reverse(decorated); // maintain stability of sort by reversing first + } + Collections.sort(decorated, c); + if (reverse) { + Collections.reverse(decorated); + } + boolean modified = list.size() > 0; + for (KV kv : decorated) { + list.add(kv.value); + } + if (modified) { + throw Py.ValueError("list modified during sort"); + } } private static class MapMethod extends PyBuiltinMethodNarrow { @@ -994,6 +1115,8 @@ } } + + /** * Build a map of common Java collection base types (Map, Iterable, etc) that need to be * injected with Python's equivalent types' builtin methods (__len__, __iter__, iteritems, etc). @@ -1024,8 +1147,20 @@ PyBuiltinMethodNarrow containsProxy = new PyBuiltinMethodNarrow("__contains__", 1) { @Override public PyObject __call__(PyObject obj) { - Object other = obj.__tojava__(Object.class); - boolean contained = ((Collection)self.getJavaProxy()).contains(other); + boolean contained = false; + Object proxy = obj.getJavaProxy(); + if (proxy == null) { + for (Object item : (Collection)self.getJavaProxy()) { + if (Py.java2py(item)._eq(obj).__nonzero__()) { + contained = true; + break; + } + } + } else { + Object other = obj.__tojava__(Object.class); + contained = ((Collection)self.getJavaProxy()).contains(other); + + } return contained ? 
Py.True : Py.False; } }; @@ -1484,13 +1619,14 @@ @Override public PyObject __call__(PyObject obj) { List jList = asList(); - if (obj instanceof Collection) { - jList.addAll((Collection) obj); - } else { - for (PyObject item : obj.asIterable()) { - jList.add(item); - } + List extension = new ArrayList(); + + // Extra step to build the extension list is necessary + // in case of adding to oneself + for (PyObject item : obj.asIterable()) { + extension.add(item); } + jList.addAll(extension); return Py.None; } }; @@ -1499,7 +1635,7 @@ public PyObject __call__(PyObject index, PyObject object) { List jlist = asList(); ListIndexDelegate lid = new ListIndexDelegate(jlist); - int idx = lid.fixBoundIndex(index.asIndex()); + int idx = lid.fixBoundIndex(index); jlist.add(idx, object); return Py.None; } @@ -1538,13 +1674,21 @@ public PyObject __call__(PyObject object, PyObject start, PyObject end) { List jlist = asList(); ListIndexDelegate lid = new ListIndexDelegate(jlist); - int st = lid.fixBoundIndex(start.asInt()); - int en = lid.fixBoundIndex(end.asInt()); - for (int i = st; i < en; i++) { - Object jobj = jlist.get(i); - if (Py.java2py(jobj)._eq(object).__nonzero__()) { - return Py.newInteger(i); + int start_index = lid.fixBoundIndex(start); + int end_index = lid.fixBoundIndex(end); + int i = start_index; + try { + for (ListIterator it = jlist.listIterator(start_index); it.hasNext(); i++) { + if (i == end_index) { + break; + } + Object jobj = it.next(); + if (Py.java2py(jobj)._eq(object).__nonzero__()) { + return Py.newInteger(i); + } } + } catch (ConcurrentModificationException e) { + throw Py.ValueError(object.toString() + " is not in list"); } throw Py.ValueError(object.toString() + " is not in list"); } @@ -1632,62 +1776,6 @@ return self; } }; - PyBuiltinMethodNarrow listRMulProxy = new ListMethod("__rmul__", 1) { - @Override - public PyObject __call__(PyObject obj) { - List jList = asList(); - int mult = obj.asInt(); - - List jClone; - try { - jClone = (List) 
jList.getClass().newInstance(); - } catch (IllegalAccessException e) { - throw Py.JavaError(e); - } catch (InstantiationException e) { - throw Py.JavaError(e); - } - - // anything below 0 multiplier, we return an empty list - if (mult > 0) { - // otherwise, extend it x times, where x is int-cast from obj - for (; mult > 0; mult--) { - for (Object entry : jList) { - jClone.add(entry); - } - } - } - - return Py.java2py(jClone); - } - }; - PyBuiltinMethodNarrow listMulProxy = new ListMethod("__mul__", 1) { - @Override - public PyObject __call__(PyObject obj) { - List jList = asList(); - int mult = obj.asInt(); - - List jClone; - try { - jClone = (List) jList.getClass().newInstance(); - } catch (IllegalAccessException e) { - throw Py.JavaError(e); - } catch (InstantiationException e) { - throw Py.JavaError(e); - } - - // anything below 0 multiplier, we return an empty list - if (mult > 0) { - // otherwise, extend it x times, where x is int-cast from obj - for (; mult > 0; mult--) { - for (Object entry : jList) { - jClone.add(entry); - } - } - } - - return Py.java2py(jClone); - } - }; PyBuiltinMethodNarrow listIMulProxy = new ListMethod("__imul__", 1) { @Override public PyObject __call__(PyObject obj) { @@ -1698,18 +1786,61 @@ if (mult <= 0) { jList.clear(); } else { - // otherwise, extend it (in-place) x times, where x is int-cast from obj - int originalSize = jList.size(); - for (mult = mult - 1; mult > 0; mult--) { - for (int i = 0; i < originalSize; i++) { - jList.add(jList.get(i)); + try { + if (jList instanceof ArrayList) { + ((ArrayList)jList).ensureCapacity(jList.size() * (mult - 1)); } + // otherwise, extend it (in-place) x times, where x is int-cast from obj + int originalSize = jList.size(); + for (mult = mult - 1; mult > 0; mult--) { + for (int i = 0; i < originalSize; i++) { + jList.add(jList.get(i)); + } + } + } catch (OutOfMemoryError t) { + throw Py.MemoryError(""); } } - return self; } }; + PyBuiltinMethodNarrow listSortProxy = new ListMethod("sort", 0, 
3) { + @Override + public PyObject __call__() { + list_sort(asList(), Py.None, Py.None, false); + return Py.None; + } + + @Override + public PyObject __call__(PyObject cmp) { + list_sort(asList(), cmp, Py.None, false); + return Py.None; + } + + @Override + public PyObject __call__(PyObject cmp, PyObject key) { + list_sort(asList(), cmp, key, false); + return Py.None; + } + + @Override + public PyObject __call__(PyObject cmp, PyObject key, PyObject reverse) { + list_sort(asList(), cmp, key, reverse.__nonzero__()); + return Py.None; + } + + @Override + public PyObject __call__(PyObject[] args, String[] kwds) { + ArgParser ap = new ArgParser("list", args, kwds, new String[]{ + "cmp", "key", "reverse"}, 0); + PyObject cmp = ap.getPyObject(0, Py.None); + PyObject key = ap.getPyObject(1, Py.None); + PyObject reverse = ap.getPyObject(2, Py.False); + list_sort(asList(), cmp, key, reverse.__nonzero__()); + return Py.None; + } + }; + collectionProxies.put(List.class, new PyBuiltinMethod[] { listGetProxy, listSetProxy, @@ -1724,9 +1855,10 @@ listReverseProxy, listRAddProxy, listIAddProxy, - listRMulProxy, - listMulProxy, + new ListMulProxyClass("__mul__", 1), + new ListMulProxyClass("__rmul__", 1), listIMulProxy, + listSortProxy, }); postCollectionProxies.put(List.class, new PyBuiltinMethod[]{ listRemoveOverrideProxy, @@ -1789,17 +1921,19 @@ return list.size(); } - protected int fixBoundIndex(int index) { - int l = len(); - if (index < 0) { - index += l; - if (index < 0) { - index = 0; + protected int fixBoundIndex(PyObject index) { + PyInteger length = Py.newInteger(len()); + if (index._lt(Py.Zero).__nonzero__()) { + index = index._add(length); + if (index._lt(Py.Zero).__nonzero__()) { + index = Py.Zero; } - } else if (index > l) { - index = l; + } else if (index._gt(length).__nonzero__()) { + index = length; } - return index; + int i = index.asIndex(); + assert i >= 0; + return i; } @Override @@ -1828,8 +1962,6 @@ } } - - final private void setsliceList(int start, int stop, 
int step, List value) { if (step == 1) { list.subList(start, stop).clear(); diff --git a/src/org/python/core/PyList.java b/src/org/python/core/PyList.java --- a/src/org/python/core/PyList.java +++ b/src/org/python/core/PyList.java @@ -14,11 +14,13 @@ import java.util.Collections; import java.util.Comparator; import java.util.ConcurrentModificationException; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.ListIterator; import java.lang.reflect.Array; +import java.util.Map; @ExposedType(name = "list", base = PyObject.class, doc = BuiltinDocs.list_doc) public class PyList extends PySequenceList implements List { @@ -101,6 +103,21 @@ this(TYPE, listify(iter)); } + // refactor and put in Py presumably; + // presumably we can consume an arbitrary iterable too! + private static void addCollection(List list, Collection seq) { + Map seen = new HashMap(); + for (Object item : seq) { + long id = Py.java_obj_id(item); + PyObject seen_obj = seen.get(id); + if (seen_obj != null) { + seen_obj = Py.java2py(item); + seen.put(id, seen_obj); + } + list.add(seen_obj); + } + } + @ExposedNew @ExposedMethod(doc = BuiltinDocs.list___init___doc) final void list___init__(PyObject[] args, String[] kwds) { @@ -114,6 +131,9 @@ list.addAll(((PyList) seq).list); // don't convert } else if (seq instanceof PyTuple) { list.addAll(((PyTuple) seq).getList()); + } else if (seq.getClass().isAssignableFrom(Collection.class)) { + System.err.println("Adding from collection"); + addCollection(list, (Collection)seq); } else { for (PyObject item : seq.asIterable()) { append(item); diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -4,6 +4,9 @@ import java.lang.ref.Reference; import java.lang.ref.SoftReference; import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import 
org.python.core.buffer.BaseBuffer; import org.python.core.buffer.SimpleStringBuffer; @@ -718,6 +721,14 @@ } } + if (c.isAssignableFrom(Collection.class)) { + List list = new ArrayList(); + for (int i = 0; i < __len__(); i++) { + list.add(pyget(i).__tojava__(String.class)); + } + return list; + } + if (c.isInstance(this)) { return this; } diff --git a/src/org/python/core/PyXRange.java b/src/org/python/core/PyXRange.java --- a/src/org/python/core/PyXRange.java +++ b/src/org/python/core/PyXRange.java @@ -5,6 +5,11 @@ import org.python.expose.ExposedNew; import org.python.expose.ExposedType; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; + /** * The builtin xrange type. */ @@ -188,4 +193,22 @@ return String.format("xrange(%d, %d, %d)", start, stop, step); } } + + @Override + public Object __tojava__(Class c) { + if (c.isAssignableFrom(Iterable.class)) { + return new JavaIterator(range_iter()); + } + if (c.isAssignableFrom(Iterator.class)) { + return (new JavaIterator(range_iter())).iterator(); + } + if (c.isAssignableFrom(Collection.class)) { + List list = new ArrayList(); + for (Object obj : new JavaIterator(range_iter())) { + list.add(obj); + } + return list; + } + return super.__tojava__(c); + } } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Sun Dec 28 21:02:53 2014 From: jython-checkins at python.org (jim.baker) Date: Sun, 28 Dec 2014 20:02:53 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Refactor_PyJavaType_so_it_i?= =?utf-8?q?s_readable_again?= Message-ID: <20141228200241.125143.58730@psf.io> https://hg.python.org/jython/rev/2aa59e8e0bf8 changeset: 7474:2aa59e8e0bf8 user: Jim Baker date: Sun Dec 28 13:02:02 2014 -0700 summary: Refactor PyJavaType so it is readable again Because of our work to support semantic equivalence of java.util.{List|Map|Set} with Python builtins of list, dict, and set, PyJavaType has grown significantly. 
In particular, it is hard to follow the logic for adding collection proxy and post-proxy methods. Do a simple refactoring to address this by moving collection logic into JavaProxyList, JavaProxyMap, and JavaProxySet (currently a placeholder for forthcoming work). Also apply the singleton holder pattern for building such proxies to simplify initialization. files: src/org/python/core/JavaProxyList.java | 637 ++++++ src/org/python/core/JavaProxyMap.java | 478 ++++ src/org/python/core/JavaProxySet.java | 48 + src/org/python/core/PyJavaType.java | 1214 +----------- 4 files changed, 1250 insertions(+), 1127 deletions(-) diff --git a/src/org/python/core/JavaProxyList.java b/src/org/python/core/JavaProxyList.java new file mode 100644 --- /dev/null +++ b/src/org/python/core/JavaProxyList.java @@ -0,0 +1,637 @@ +package org.python.core; + +/** + * Proxy Java objects implementing java.util.List with Python methods + * corresponding to the standard list type + */ + +import org.python.util.Generic; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.ConcurrentModificationException; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; + + +class JavaProxyList { + + private static class ListMethod extends PyBuiltinMethodNarrow { + protected ListMethod(String name, int numArgs) { + super(name, numArgs); + } + + protected ListMethod(String name, int minArgs, int maxArgs) { + super(name, minArgs, maxArgs); + } + + protected List asList() { + return (List) self.getJavaProxy(); + } + + protected List newList() { + try { + return (List) asList().getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } + } + } + + protected static class ListIndexDelegate extends SequenceIndexDelegate { + + private final List list; + + public ListIndexDelegate(List list) { + this.list = 
list; + } + + @Override + public void delItem(int idx) { + list.remove(idx); + } + + @Override + public PyObject getItem(int idx) { + return Py.java2py(list.get(idx)); + } + + @Override + public PyObject getSlice(int start, int stop, int step) { + if (step > 0 && stop < start) { + stop = start; + } + int n = PySequence.sliceLength(start, stop, step); + List newList; + try { + newList = list.getClass().newInstance(); + } catch (Exception e) { + throw Py.JavaError(e); + } + int j = 0; + for (int i = start; j < n; i += step) { + newList.add(list.get(i)); + j++; + } + return Py.java2py(newList); + } + + @Override + public String getTypeName() { + return list.getClass().getName(); + } + + @Override + public int len() { + return list.size(); + } + + protected int fixBoundIndex(PyObject index) { + PyInteger length = Py.newInteger(len()); + if (index._lt(Py.Zero).__nonzero__()) { + index = index._add(length); + if (index._lt(Py.Zero).__nonzero__()) { + index = Py.Zero; + } + } else if (index._gt(length).__nonzero__()) { + index = length; + } + int i = index.asIndex(); + assert i >= 0; + return i; + } + + @Override + public void setItem(int idx, PyObject value) { + list.set(idx, value.__tojava__(Object.class)); + } + + @Override + public void setSlice(int start, int stop, int step, PyObject value) { + if (stop < start) { + stop = start; + } + if (value.javaProxy == this.list) { + List xs = Generic.list(); + xs.addAll(this.list); + setsliceList(start, stop, step, xs); + } else if (value instanceof PyList) { + setslicePyList(start, stop, step, (PyList) value); + } else { + Object valueList = value.__tojava__(List.class); + if (valueList != null && valueList != Py.NoConversion) { + setsliceList(start, stop, step, (List) valueList); + } else { + setsliceIterator(start, stop, step, value.asIterable().iterator()); + } + } + } + + final private void setsliceList(int start, int stop, int step, List value) { + if (step == 1) { + list.subList(start, stop).clear(); + 
list.addAll(start, value); + } else { + int size = list.size(); + Iterator iter = value.listIterator(); + for (int j = start; iter.hasNext(); j += step) { + Object item = iter.next(); + if (j >= size) { + list.add(item); + } else { + list.set(j, item); + } + } + } + } + + final private void setsliceIterator(int start, int stop, int step, Iterator iter) { + if (step == 1) { + List insertion = new ArrayList(); + if (iter != null) { + while (iter.hasNext()) { + insertion.add(iter.next().__tojava__(Object.class)); + } + } + list.subList(start, stop).clear(); + list.addAll(start, insertion); + } else { + int size = list.size(); + for (int j = start; iter.hasNext(); j += step) { + Object item = iter.next().__tojava__(Object.class); + if (j >= size) { + list.add(item); + } else { + list.set(j, item); + } + } + } + } + + final private void setslicePyList(int start, int stop, int step, PyList value) { + if (step == 1) { + list.subList(start, stop).clear(); + int n = value.getList().size(); + for (int i = 0, j = start; i < n; i++, j++) { + Object item = value.getList().get(i).__tojava__(Object.class); + list.add(j, item); + } + } else { + int size = list.size(); + Iterator iter = value.getList().listIterator(); + for (int j = start; iter.hasNext(); j += step) { + Object item = iter.next().__tojava__(Object.class); + if (j >= size) { + list.add(item); + } else { + list.set(j, item); + } + } + } + } + + + @Override + public void delItems(int start, int stop) { + int n = stop - start; + while (n-- > 0) { + delItem(start); + } + } + } + + + private static class ListMulProxyClass extends ListMethod { + protected ListMulProxyClass(String name, int numArgs) { + super(name, numArgs); + } + + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + int mult = obj.asInt(); + List newList = null; + // anything below 0 multiplier, we return an empty list + if (mult > 0) { + try { + newList = new ArrayList(jList.size() * mult); + // otherwise, extend it x times, 
where x is int-cast from obj + for (; mult > 0; mult--) { + for (Object entry : jList) { + newList.add(entry); + } + } + } catch (OutOfMemoryError t) { + throw Py.MemoryError(""); + } + } else { + newList = Collections.EMPTY_LIST; + } + return Py.java2py(newList); + } + } + + + private static class KV { + + private final PyObject key; + private final Object value; + + KV(PyObject key, Object value) { + this.key = key; + this.value = value; + } + } + + private static class KVComparator implements Comparator { + + private final PyObject cmp; + + KVComparator(PyObject cmp) { + this.cmp = cmp; + } + + public int compare(KV o1, KV o2) { + int result; + if (cmp != null && cmp != Py.None) { + PyObject pyresult = cmp.__call__(o1.key, o2.key); + if (pyresult instanceof PyInteger || pyresult instanceof PyLong) { + return pyresult.asInt(); + } else { + throw Py.TypeError( + String.format("comparison function must return int, not %.200s", + pyresult.getType().fastGetName())); + } + } else { + result = o1.key._cmp(o2.key); + } + return result; + } + + public boolean equals(Object o) { + if (o == this) { + return true; + } + + if (o instanceof KVComparator) { + return cmp.equals(((KVComparator) o).cmp); + } + return false; + } + } + + private synchronized static void list_sort(List list, PyObject cmp, PyObject key, boolean reverse) { + int size = list.size(); + final ArrayList decorated = new ArrayList(size); + for (Object value : list) { + PyObject pyvalue = Py.java2py(value); + if (key == null || key == Py.None) { + decorated.add(new KV(pyvalue, value)); + } else { + decorated.add(new KV(key.__call__(pyvalue), value)); + } + } + // we will rebuild the list from the decorated version + list.clear(); + KVComparator c = new KVComparator(cmp); + if (reverse) { + Collections.reverse(decorated); // maintain stability of sort by reversing first + } + Collections.sort(decorated, c); + if (reverse) { + Collections.reverse(decorated); + } + boolean modified = list.size() > 0; + for (KV 
kv : decorated) { + list.add(kv.value); + } + if (modified) { + throw Py.ValueError("list modified during sort"); + } + } + + private static final PyBuiltinMethodNarrow listGetProxy = new ListMethod("__getitem__", 1) { + @Override + public PyObject __call__(PyObject key) { + return new ListIndexDelegate(asList()).checkIdxAndGetItem(key); + } + }; + private static final PyBuiltinMethodNarrow listSetProxy = new ListMethod("__setitem__", 2) { + @Override + public PyObject __call__(PyObject key, PyObject value) { + new ListIndexDelegate(asList()).checkIdxAndSetItem(key, value); + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow listRemoveProxy = new ListMethod("__delitem__", 1) { + @Override + public PyObject __call__(PyObject key) { + new ListIndexDelegate(asList()).checkIdxAndDelItem(key); + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow listEqProxy = new ListMethod("__eq__", 1) { + @Override + public PyObject __call__(PyObject other) { + List jList = asList(); + if (other.getType().isSubType(PyList.TYPE)) { + PyList oList = (PyList) other; + if (jList.size() != oList.size()) { + return Py.False; + } + for (int i = 0; i < jList.size(); i++) { + if (!Py.java2py(jList.get(i))._eq(oList.pyget(i)).__nonzero__()) { + return Py.False; + } + } + return Py.True; + } else { + Object oj = other.getJavaProxy(); + if (oj instanceof List) { + List oList = (List) oj; + if (jList.size() != oList.size()) { + return Py.False; + } + for (int i = 0; i < jList.size(); i++) { + if (!Py.java2py(jList.get(i))._eq( + Py.java2py(oList.get(i))).__nonzero__()) { + return Py.False; + } + } + return Py.True; + } else { + return null; + } + } + } + }; + private static final PyBuiltinMethodNarrow listAppendProxy = new ListMethod("append", 1) { + @Override + public PyObject __call__(PyObject value) { + asList().add(value); + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow listExtendProxy = new ListMethod("extend", 1) { + @Override + 
public PyObject __call__(PyObject obj) { + List jList = asList(); + List extension = new ArrayList(); + + // Extra step to build the extension list is necessary + // in case of adding to oneself + for (PyObject item : obj.asIterable()) { + extension.add(item); + } + jList.addAll(extension); + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow listInsertProxy = new ListMethod("insert", 2) { + @Override + public PyObject __call__(PyObject index, PyObject object) { + List jlist = asList(); + ListIndexDelegate lid = new ListIndexDelegate(jlist); + int idx = lid.fixBoundIndex(index); + jlist.add(idx, object); + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow listPopProxy = new ListMethod("pop", 0, 1) { + @Override + public PyObject __call__() { + return __call__(Py.newInteger(-1)); + } + + @Override + public PyObject __call__(PyObject index) { + List jlist = asList(); + if (jlist.isEmpty()) { + throw Py.IndexError("pop from empty list"); + } + ListIndexDelegate ldel = new ListIndexDelegate(jlist); + PyObject item = ldel.checkIdxAndFindItem(index.asInt()); + if (item == null) { + throw Py.IndexError("pop index out of range"); + } else { + ldel.checkIdxAndDelItem(index); + return item; + } + } + }; + private static final PyBuiltinMethodNarrow listIndexProxy = new ListMethod("index", 1, 3) { + @Override + public PyObject __call__(PyObject object) { + return __call__(object, Py.newInteger(0), Py.newInteger(asList().size())); + } + + @Override + public PyObject __call__(PyObject object, PyObject start) { + return __call__(object, start, Py.newInteger(asList().size())); + } + + @Override + public PyObject __call__(PyObject object, PyObject start, PyObject end) { + List jlist = asList(); + ListIndexDelegate lid = new ListIndexDelegate(jlist); + int start_index = lid.fixBoundIndex(start); + int end_index = lid.fixBoundIndex(end); + int i = start_index; + try { + for (ListIterator it = jlist.listIterator(start_index); it.hasNext(); i++) { 
+ if (i == end_index) { + break; + } + Object jobj = it.next(); + if (Py.java2py(jobj)._eq(object).__nonzero__()) { + return Py.newInteger(i); + } + } + } catch (ConcurrentModificationException e) { + throw Py.ValueError(object.toString() + " is not in list"); + } + throw Py.ValueError(object.toString() + " is not in list"); + } + }; + private static final PyBuiltinMethodNarrow listCountProxy = new ListMethod("count", 1) { + @Override + public PyObject __call__(PyObject object) { + int count = 0; + List jlist = asList(); + for (int i = 0; i < jlist.size(); i++) { + Object jobj = jlist.get(i); + if (Py.java2py(jobj)._eq(object).__nonzero__()) { + ++count; + } + } + return Py.newInteger(count); + } + }; + private static final PyBuiltinMethodNarrow listReverseProxy = new ListMethod("reverse", 0) { + @Override + public PyObject __call__() { + List jlist = asList(); + Collections.reverse(jlist); + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow listRemoveOverrideProxy = new ListMethod("remove", 1) { + @Override + public PyObject __call__(PyObject object) { + List jlist = asList(); + for (int i = 0; i < jlist.size(); i++) { + Object jobj = jlist.get(i); + if (Py.java2py(jobj)._eq(object).__nonzero__()) { + jlist.remove(i); + return Py.None; + } + } + throw Py.ValueError(object.toString() + " is not in list"); + } + }; + private static final PyBuiltinMethodNarrow listRAddProxy = new ListMethod("__radd__", 1) { + @Override + public PyObject __call__(PyObject obj) { + // first, clone the self list + List jList = asList(); + List jClone; + try { + jClone = (List) jList.getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } + for (Object entry : jList) { + jClone.add(entry); + } + + // then, extend it with elements from the other list + // (but, since this is reverse add, we are technically + // pre-pending the clone with elements from the other list) + 
if (obj instanceof Collection) { + jClone.addAll(0, (Collection) obj); + } else { + int i = 0; + for (PyObject item : obj.asIterable()) { + jClone.add(i, item); + i++; + } + } + + return Py.java2py(jClone); + } + }; + private static final PyBuiltinMethodNarrow listIAddProxy = new ListMethod("__iadd__", 1) { + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + if (obj instanceof Collection) { + jList.addAll((Collection) obj); + } else { + for (PyObject item : obj.asIterable()) { + jList.add(item); + } + } + return self; + } + }; + private static final PyBuiltinMethodNarrow listIMulProxy = new ListMethod("__imul__", 1) { + @Override + public PyObject __call__(PyObject obj) { + List jList = asList(); + int mult = obj.asInt(); + + // anything below 0 multiplier, we clear the list + if (mult <= 0) { + jList.clear(); + } else { + try { + if (jList instanceof ArrayList) { + ((ArrayList) jList).ensureCapacity(jList.size() * (mult - 1)); + } + // otherwise, extend it (in-place) x times, where x is int-cast from obj + int originalSize = jList.size(); + for (mult = mult - 1; mult > 0; mult--) { + for (int i = 0; i < originalSize; i++) { + jList.add(jList.get(i)); + } + } + } catch (OutOfMemoryError t) { + throw Py.MemoryError(""); + } + } + return self; + } + }; + private static final PyBuiltinMethodNarrow listSortProxy = new ListMethod("sort", 0, 3) { + @Override + public PyObject __call__() { + list_sort(asList(), Py.None, Py.None, false); + return Py.None; + } + + @Override + public PyObject __call__(PyObject cmp) { + list_sort(asList(), cmp, Py.None, false); + return Py.None; + } + + @Override + public PyObject __call__(PyObject cmp, PyObject key) { + list_sort(asList(), cmp, key, false); + return Py.None; + } + + @Override + public PyObject __call__(PyObject cmp, PyObject key, PyObject reverse) { + list_sort(asList(), cmp, key, reverse.__nonzero__()); + return Py.None; + } + + @Override + public PyObject __call__(PyObject[] args, String[] 
kwds) { + ArgParser ap = new ArgParser("list", args, kwds, new String[]{ + "cmp", "key", "reverse"}, 0); + PyObject cmp = ap.getPyObject(0, Py.None); + PyObject key = ap.getPyObject(1, Py.None); + PyObject reverse = ap.getPyObject(2, Py.False); + list_sort(asList(), cmp, key, reverse.__nonzero__()); + return Py.None; + } + }; + + static PyBuiltinMethod[] getProxyMethods() { + return new PyBuiltinMethod[]{ + listGetProxy, + listSetProxy, + listEqProxy, + listRemoveProxy, + listAppendProxy, + listExtendProxy, + listInsertProxy, + listPopProxy, + listIndexProxy, + listCountProxy, + listReverseProxy, + listRAddProxy, + listIAddProxy, + new ListMulProxyClass("__mul__", 1), + new ListMulProxyClass("__rmul__", 1), + listIMulProxy, + listSortProxy, + }; + } + + static PyBuiltinMethod[] getPostProxyMethods() { + return new PyBuiltinMethod[]{ + listRemoveOverrideProxy + }; + } + +} diff --git a/src/org/python/core/JavaProxyMap.java b/src/org/python/core/JavaProxyMap.java new file mode 100644 --- /dev/null +++ b/src/org/python/core/JavaProxyMap.java @@ -0,0 +1,478 @@ +package org.python.core; + +import java.util.Iterator; +import java.util.Map; +import java.util.Set; + +/** + * Proxy Java objects implementing java.util.Map with Python methods + * corresponding to the standard dict type + */ + + +class JavaProxyMap { + + private static class MapMethod extends PyBuiltinMethodNarrow { + protected MapMethod(String name, int numArgs) { + super(name, numArgs); + } + + protected MapMethod(String name, int minArgs, int maxArgs) { + super(name, minArgs, maxArgs); + } + + protected Map asMap() { + return (Map) self.getJavaProxy(); + } + } + + private static class MapClassMethod extends PyBuiltinClassMethodNarrow { + protected MapClassMethod(String name, int minArgs, int maxArgs) { + super(name, minArgs, maxArgs); + } + + protected Class asClass() { + return (Class) self.getJavaProxy(); + } + } + + private static PyObject mapEq(PyObject self, PyObject other) { + Map selfMap = ((Map) 
self.getJavaProxy()); + if (other.getType().isSubType(PyDictionary.TYPE)) { + PyDictionary oDict = (PyDictionary) other; + if (selfMap.size() != oDict.size()) { + return Py.False; + } + for (Object jkey : selfMap.keySet()) { + Object jval = selfMap.get(jkey); + PyObject oVal = oDict.__finditem__(Py.java2py(jkey)); + if (oVal == null) { + return Py.False; + } + if (!Py.java2py(jval)._eq(oVal).__nonzero__()) { + return Py.False; + } + } + return Py.True; + } else { + Object oj = other.getJavaProxy(); + if (oj instanceof Map) { + Map oMap = (Map) oj; + return Py.newBoolean(selfMap.equals(oMap)); + } else { + return null; + } + } + } + + // Map ordering comparisons (lt, le, gt, ge) are based on the key sets; + // we just define mapLe + mapEq for total ordering of such key sets + private static PyObject mapLe(PyObject self, PyObject other) { + Set selfKeys = ((Map) self.getJavaProxy()).keySet(); + if (other.getType().isSubType(PyDictionary.TYPE)) { + PyDictionary oDict = (PyDictionary) other; + for (Object jkey : selfKeys) { + if (!oDict.__contains__(Py.java2py(jkey))) { + return Py.False; + } + } + return Py.True; + } else { + Object oj = other.getJavaProxy(); + if (oj instanceof Map) { + Map oMap = (Map) oj; + return Py.newBoolean(oMap.keySet().containsAll(selfKeys)); + } else { + return null; + } + } + } + + // Map doesn't extend Collection, so it needs its own version of len, iter and contains + private static final PyBuiltinMethodNarrow mapLenProxy = new MapMethod("__len__", 0) { + @Override + public PyObject __call__() { + return Py.java2py(asMap().size()); + } + }; + private static final PyBuiltinMethodNarrow mapReprProxy = new MapMethod("__repr__", 0) { + @Override + public PyObject __call__() { + StringBuilder repr = new StringBuilder("{"); + for (Map.Entry entry : asMap().entrySet()) { + Object jkey = entry.getKey(); + Object jval = entry.getValue(); + repr.append(jkey.toString()); + repr.append(": "); + repr.append(jval == asMap() ? "{...}" : (jval == null ? 
"None" : jval.toString())); + repr.append(", "); + } + int lastindex = repr.lastIndexOf(", "); + if (lastindex > -1) { + repr.delete(lastindex, lastindex + 2); + } + repr.append("}"); + return new PyString(repr.toString()); + } + }; + private static final PyBuiltinMethodNarrow mapEqProxy = new MapMethod("__eq__", 1) { + @Override + public PyObject __call__(PyObject other) { + return mapEq(self, other); + } + }; + private static final PyBuiltinMethodNarrow mapLeProxy = new MapMethod("__le__", 1) { + @Override + public PyObject __call__(PyObject other) { + return mapLe(self, other); + } + }; + private static final PyBuiltinMethodNarrow mapGeProxy = new MapMethod("__ge__", 1) { + @Override + public PyObject __call__(PyObject other) { + return (mapLe(self, other).__not__()).__or__(mapEq(self, other)); + } + }; + private static final PyBuiltinMethodNarrow mapLtProxy = new MapMethod("__lt__", 1) { + @Override + public PyObject __call__(PyObject other) { + return mapLe(self, other).__and__(mapEq(self, other).__not__()); + } + }; + private static final PyBuiltinMethodNarrow mapGtProxy = new MapMethod("__gt__", 1) { + @Override + public PyObject __call__(PyObject other) { + return mapLe(self, other).__not__(); + } + }; + private static final PyBuiltinMethodNarrow mapIterProxy = new MapMethod("__iter__", 0) { + @Override + public PyObject __call__() { + return new JavaIterator(asMap().keySet()); + } + }; + private static final PyBuiltinMethodNarrow mapContainsProxy = new MapMethod("__contains__", 1) { + @Override + public PyObject __call__(PyObject obj) { + Object other = obj.__tojava__(Object.class); + return asMap().containsKey(other) ? Py.True : Py.False; + } + }; + // "get" needs to override java.util.Map#get() in its subclasses, too, so this needs to be injected last + // (i.e. 
when HashMap is loaded not when it is recursively loading its super-type Map) + private static final PyBuiltinMethodNarrow mapGetProxy = new MapMethod("get", 1, 2) { + @Override + public PyObject __call__(PyObject key) { + return __call__(key, Py.None); + } + + @Override + public PyObject __call__(PyObject key, PyObject _default) { + Object jkey = Py.tojava(key, Object.class); + if (asMap().containsKey(jkey)) { + return Py.java2py(asMap().get(jkey)); + } else { + return _default; + } + } + }; + private static final PyBuiltinMethodNarrow mapGetItemProxy = new MapMethod("__getitem__", 1) { + @Override + public PyObject __call__(PyObject key) { + Object jkey = Py.tojava(key, Object.class); + if (asMap().containsKey(jkey)) { + return Py.java2py(asMap().get(jkey)); + } else { + throw Py.KeyError(key); + } + } + }; + private static final PyBuiltinMethodNarrow mapPutProxy = new MapMethod("__setitem__", 2) { + @Override + public PyObject __call__(PyObject key, PyObject value) { + asMap().put(Py.tojava(key, Object.class), + value == Py.None ? 
Py.None : Py.tojava(value, Object.class)); + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow mapRemoveProxy = new MapMethod("__delitem__", 1) { + @Override + public PyObject __call__(PyObject key) { + Object jkey = Py.tojava(key, Object.class); + if (asMap().remove(jkey) == null) { + throw Py.KeyError(key); + } + return Py.None; + } + }; + private static final PyBuiltinMethodNarrow mapIterItemsProxy = new MapMethod("iteritems", 0) { + @Override + public PyObject __call__() { + final Iterator> entrySetIterator = asMap().entrySet().iterator(); + return new PyIterator() { + @Override + public PyObject __iternext__() { + if (entrySetIterator.hasNext()) { + Map.Entry nextEntry = entrySetIterator.next(); + // yield a Python tuple object (key, value) + return new PyTuple(Py.java2py(nextEntry.getKey()), + Py.java2py(nextEntry.getValue())); + } + return null; + } + }; + } + }; + private static final PyBuiltinMethodNarrow mapHasKeyProxy = new MapMethod("has_key", 1) { + @Override + public PyObject __call__(PyObject key) { + return asMap().containsKey(Py.tojava(key, Object.class)) ? 
Py.True : Py.False; + } + }; + private static final PyBuiltinMethodNarrow mapKeysProxy = new MapMethod("keys", 0) { + @Override + public PyObject __call__() { + PyList keys = new PyList(); + for (Object key : asMap().keySet()) { + keys.add(Py.java2py(key)); + } + return keys; + } + }; + private static final PyBuiltinMethod mapValuesProxy = new MapMethod("values", 0) { + @Override + public PyObject __call__() { + PyList values = new PyList(); + for (Object value : asMap().values()) { + values.add(Py.java2py(value)); + } + return values; + } + }; + private static final PyBuiltinMethodNarrow mapSetDefaultProxy = new MapMethod("setdefault", 1, 2) { + @Override + public PyObject __call__(PyObject key) { + return __call__(key, Py.None); + } + + @Override + public PyObject __call__(PyObject key, PyObject _default) { + Object jkey = Py.tojava(key, Object.class); + Object jval = asMap().get(jkey); + if (jval == null) { + asMap().put(jkey, _default == Py.None ? Py.None : Py.tojava(_default, Object.class)); + return _default; + } + return Py.java2py(jval); + } + }; + private static final PyBuiltinMethodNarrow mapPopProxy = new MapMethod("pop", 1, 2) { + @Override + public PyObject __call__(PyObject key) { + return __call__(key, null); + } + + @Override + public PyObject __call__(PyObject key, PyObject _default) { + Object jkey = Py.tojava(key, Object.class); + if (asMap().containsKey(jkey)) { + PyObject value = Py.java2py(asMap().remove(jkey)); + assert (value != null); + return Py.java2py(value); + } else { + if (_default == null) { + throw Py.KeyError(key); + } + return _default; + } + } + }; + private static final PyBuiltinMethodNarrow mapPopItemProxy = new MapMethod("popitem", 0) { + @Override + public PyObject __call__() { + if (asMap().size() == 0) { + throw Py.KeyError("popitem(): map is empty"); + } + Object key = asMap().keySet().toArray()[0]; + Object val = asMap().remove(key); + return Py.java2py(val); + } + }; + private static final PyBuiltinMethodNarrow 
mapItemsProxy = new MapMethod("items", 0) { + @Override + public PyObject __call__() { + PyList items = new PyList(); + for (Map.Entry entry : asMap().entrySet()) { + items.add(new PyTuple(Py.java2py(entry.getKey()), + Py.java2py(entry.getValue()))); + } + return items; + } + }; + private static final PyBuiltinMethodNarrow mapCopyProxy = new MapMethod("copy", 0) { + @Override + public PyObject __call__() { + Map jmap = asMap(); + Map jclone; + try { + jclone = (Map) jmap.getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } + for (Map.Entry entry : jmap.entrySet()) { + jclone.put(entry.getKey(), entry.getValue()); + } + return Py.java2py(jclone); + } + }; + private static final PyBuiltinMethodNarrow mapUpdateProxy = new MapMethod("update", 0, 1) { + private Map jmap; + + @Override + public PyObject __call__() { + return Py.None; + } + + @Override + public PyObject __call__(PyObject other) { + // `other` is either another dict-like object, or an iterable of key/value pairs (as tuples + // or other iterables of length two) + return __call__(new PyObject[]{other}, new String[]{}); + } + + @Override + public PyObject __call__(PyObject[] args, String[] keywords) { + if ((args.length - keywords.length) != 1) { + throw info.unexpectedCall(args.length, false); + } + jmap = asMap(); + PyObject other = args[0]; + // update with entries from `other` (adapted from their equivalent in PyDictionary#update) + Object proxy = other.getJavaProxy(); + if (proxy instanceof Map) { + merge((Map) proxy); + } else if (other.__findattr__("keys") != null) { + merge(other); + } else { + mergeFromSeq(other); + } + // update with entries from keyword arguments + for (int i = 0; i < keywords.length; i++) { + String jkey = keywords[i]; + PyObject value = args[1 + i]; + jmap.put(jkey, Py.tojava(value, Object.class)); + } + return Py.None; + } + + private void merge(Map other) { + for (Map.Entry 
entry : other.entrySet()) { + jmap.put(entry.getKey(), entry.getValue()); + } + } + + private void merge(PyObject other) { + if (other instanceof PyDictionary) { + jmap.putAll(((PyDictionary) other).getMap()); + } else if (other instanceof PyStringMap) { + mergeFromKeys(other, ((PyStringMap) other).keys()); + } else { + mergeFromKeys(other, other.invoke("keys")); + } + } + + private void mergeFromKeys(PyObject other, PyObject keys) { + for (PyObject key : keys.asIterable()) { + jmap.put(Py.tojava(key, Object.class), + Py.tojava(other.__getitem__(key), Object.class)); + } + } + + private void mergeFromSeq(PyObject other) { + PyObject pairs = other.__iter__(); + PyObject pair; + + for (int i = 0; (pair = pairs.__iternext__()) != null; i++) { + try { + pair = PySequence.fastSequence(pair, ""); + } catch (PyException pye) { + if (pye.match(Py.TypeError)) { + throw Py.TypeError(String.format("cannot convert dictionary update sequence " + + "element #%d to a sequence", i)); + } + throw pye; + } + int n; + if ((n = pair.__len__()) != 2) { + throw Py.ValueError(String.format("dictionary update sequence element #%d " + + "has length %d; 2 is required", i, n)); + } + jmap.put(Py.tojava(pair.__getitem__(0), Object.class), + Py.tojava(pair.__getitem__(1), Object.class)); + } + } + }; + private static final PyBuiltinClassMethodNarrow mapFromKeysProxy = new MapClassMethod("fromkeys", 1, 2) { + @Override + public PyObject __call__(PyObject keys) { + return __call__(keys, null); + } + + @Override + public PyObject __call__(PyObject keys, PyObject _default) { + Object defobj = _default == null ? 
Py.None : Py.tojava(_default, Object.class); + Class theClass = asClass(); + try { + // always injected to java.util.Map, so we know the class object we get from asClass is subtype of java.util.Map + Map theMap = (Map) theClass.newInstance(); + for (PyObject key : keys.asIterable()) { + theMap.put(Py.tojava(key, Object.class), defobj); + } + return Py.java2py(theMap); + } catch (InstantiationException e) { + throw Py.JavaError(e); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } + } + }; + + static PyBuiltinMethod[] getProxyMethods() { + return new PyBuiltinMethod[]{ + mapLenProxy, + // map IterProxy can conflict with Iterable.class; + // fix after the fact in handleMroError + mapIterProxy, + mapReprProxy, + mapEqProxy, + mapLeProxy, + mapLtProxy, + mapGeProxy, + mapGtProxy, + mapContainsProxy, + mapGetItemProxy, + mapPutProxy, + mapRemoveProxy, + mapIterItemsProxy, + mapHasKeyProxy, + mapKeysProxy, + mapSetDefaultProxy, + mapPopProxy, + mapPopItemProxy, + mapItemsProxy, + mapCopyProxy, + mapUpdateProxy, + mapFromKeysProxy // class method + + }; + } + + static PyBuiltinMethod[] getPostProxyMethods() { + return new PyBuiltinMethod[]{ + mapGetProxy, + mapValuesProxy + }; + } +} diff --git a/src/org/python/core/JavaProxySet.java b/src/org/python/core/JavaProxySet.java new file mode 100644 --- /dev/null +++ b/src/org/python/core/JavaProxySet.java @@ -0,0 +1,48 @@ +package org.python.core; + +import java.util.Set; + +/** Proxy objects implementing java.util.Set */ + +class JavaProxySet { + + private static class SetMethod extends PyBuiltinMethodNarrow { + protected SetMethod(String name, int numArgs) { + super(name, numArgs); + } + + protected SetMethod(String name, int minArgs, int maxArgs) { + super(name, minArgs, maxArgs); + } + + protected Set asSet(){ + return (Set)self.getJavaProxy(); + } + + protected Set newSet() { + try { + return (Set) asSet().getClass().newInstance(); + } catch (IllegalAccessException e) { + throw Py.JavaError(e); + } catch 
(InstantiationException e) { + throw Py.JavaError(e); + } + } + } + + private static final PyBuiltinMethodNarrow setIsDisjointProxy = new SetMethod("isdisjoint", 1) { + @Override + public PyObject __call__(PyObject other) { + return Py.None; + } + }; + + static PyBuiltinMethod[] getProxyMethods() { + return new PyBuiltinMethod[]{}; + } + + static PyBuiltinMethod[] getPostProxyMethods() { + return new PyBuiltinMethod[]{}; + } + +} diff --git a/src/org/python/core/PyJavaType.java b/src/org/python/core/PyJavaType.java --- a/src/org/python/core/PyJavaType.java +++ b/src/org/python/core/PyJavaType.java @@ -15,13 +15,22 @@ import java.lang.reflect.Member; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.EventListener; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; -import com.google.common.collect.Lists; import org.python.core.util.StringUtil; import org.python.util.Generic; + public class PyJavaType extends PyType { private final static Class[] OO = {PyObject.class, PyObject.class}; @@ -59,8 +68,7 @@ java.net.URI.class, java.util.concurrent.TimeUnit.class); - private static Map, PyBuiltinMethod[]> collectionProxies; - private static Map, PyBuiltinMethod[]> postCollectionProxies; + /** * Other Java classes this type has MRO conflicts with. 
This doesn't matter for Java method @@ -531,7 +539,7 @@ addMethod(meth); } } - // allow for some methods to override the Java type's as a late injection + // allow for some methods to override the Java type's methods as a late injection for (Class type : getPostCollectionProxies().keySet()) { if (type.isAssignableFrom(forClass)) { for (PyBuiltinMethod meth : getPostCollectionProxies().get(type)) { @@ -869,181 +877,6 @@ } } - private static class IteratorIter extends PyIterator { - - private Iterator proxy; - - public IteratorIter(Iterable proxy) { - this(proxy.iterator()); - } - - public IteratorIter(Iterator proxy) { - this.proxy = proxy; - } - - public PyObject __iternext__() { - return proxy.hasNext() ? Py.java2py(proxy.next()) : null; - } - } - - private static class ListMethod extends PyBuiltinMethodNarrow { - protected ListMethod(String name, int numArgs) { - super(name, numArgs); - } - - protected ListMethod(String name, int minArgs, int maxArgs) { - super(name, minArgs, maxArgs); - } - - protected List asList(){ - return (List)self.getJavaProxy(); - } - - protected List newList() { - try { - return (List) asList().getClass().newInstance(); - } catch (IllegalAccessException e) { - throw Py.JavaError(e); - } catch (InstantiationException e) { - throw Py.JavaError(e); - } - } - } - - private static class ListMulProxyClass extends ListMethod { - protected ListMulProxyClass(String name, int numArgs) { - super(name, numArgs); - } - - @Override - public PyObject __call__(PyObject obj) { - List jList = asList(); - int mult = obj.asInt(); - List newList = null; - // anything below 0 multiplier, we return an empty list - if (mult > 0) { - try { - newList = new ArrayList(jList.size() * mult); - // otherwise, extend it x times, where x is int-cast from obj - for (; mult > 0; mult--) { - for (Object entry : jList) { - newList.add(entry); - } - } - } catch (OutOfMemoryError t) { - throw Py.MemoryError(""); - } - } else { - newList = Collections.EMPTY_LIST; - } - return 
Py.java2py(newList); - } - } - - - private static class KV { - - private final PyObject key; - private final Object value; - - KV(PyObject key, Object value) { - this.key = key; - this.value = value; - } - } - - private static class KVComparator implements Comparator { - - private final PyObject cmp; - - KVComparator(PyObject cmp) { - this.cmp = cmp; - } - - public int compare(KV o1, KV o2) { - int result; - if (cmp != null && cmp != Py.None) { - PyObject pyresult = cmp.__call__(o1.key, o2.key); - if (pyresult instanceof PyInteger || pyresult instanceof PyLong) { - return pyresult.asInt(); - } else { - throw Py.TypeError( - String.format("comparison function must return int, not %.200s", - pyresult.getType().fastGetName())); - } - } else { - result = o1.key._cmp(o2.key); - } - return result; - } - - public boolean equals(Object o) { - if (o == this) { - return true; - } - - if (o instanceof KVComparator) { - return cmp.equals(((KVComparator) o).cmp); - } - return false; - } - } - - private synchronized static void list_sort(List list, PyObject cmp, PyObject key, boolean reverse) { - int size = list.size(); - final ArrayList decorated = new ArrayList(size); - for (Object value : list) { - PyObject pyvalue = Py.java2py(value); - if (key == null || key == Py.None) { - decorated.add(new KV(pyvalue, value)); - } - else { - decorated.add(new KV(key.__call__(pyvalue), value)); - } - } - // we will rebuild the list from the decorated version - list.clear(); - KVComparator c = new KVComparator(cmp); - if (reverse) { - Collections.reverse(decorated); // maintain stability of sort by reversing first - } - Collections.sort(decorated, c); - if (reverse) { - Collections.reverse(decorated); - } - boolean modified = list.size() > 0; - for (KV kv : decorated) { - list.add(kv.value); - } - if (modified) { - throw Py.ValueError("list modified during sort"); - } - } - - private static class MapMethod extends PyBuiltinMethodNarrow { - protected MapMethod(String name, int numArgs) { - 
super(name, numArgs); - } - - protected MapMethod(String name, int minArgs, int maxArgs) { - super(name, minArgs, maxArgs); - } - - protected Map asMap(){ - return (Map)self.getJavaProxy(); - } - } - - private static class MapClassMethod extends PyBuiltinClassMethodNarrow { - protected MapClassMethod(String name, int minArgs, int maxArgs) { - super(name, minArgs, maxArgs); - } - - protected Class asClass() { - return (Class) self.getJavaProxy(); - } - } - private static abstract class ComparableMethod extends PyBuiltinMethodNarrow { protected ComparableMethod(String name, int numArgs) { super(name, numArgs); @@ -1063,59 +896,27 @@ protected abstract boolean getResult(int comparison); } - private static PyObject mapEq(PyObject self, PyObject other) { - Map selfMap = ((Map) self.getJavaProxy()); - if (other.getType().isSubType(PyDictionary.TYPE)) { - PyDictionary oDict = (PyDictionary) other; - if (selfMap.size() != oDict.size()) { - return Py.False; - } - for (Object jkey : selfMap.keySet()) { - Object jval = selfMap.get(jkey); - PyObject oVal = oDict.__finditem__(Py.java2py(jkey)); - if (oVal == null) { - return Py.False; - } - if (!Py.java2py(jval)._eq(oVal).__nonzero__()) { - return Py.False; - } - } - return Py.True; - } else { - Object oj = other.getJavaProxy(); - if (oj instanceof Map) { - Map oMap = (Map) oj; - return Py.newBoolean(selfMap.equals(oMap)); - } else { - return null; - } + private static class CollectionProxies { + final Map, PyBuiltinMethod[]> proxies; + final Map, PyBuiltinMethod[]> postProxies; + + CollectionProxies() { + proxies = buildCollectionProxies(); + postProxies = buildPostCollectionProxies(); } } - // Map ordering comparisons (lt, le, gt, ge) are based on the key sets; - // we just define mapLe + mapEq for total ordering of such key sets - private static PyObject mapLe(PyObject self, PyObject other) { - Set selfKeys = ((Map) self.getJavaProxy()).keySet(); - if (other.getType().isSubType(PyDictionary.TYPE)) { - PyDictionary oDict = 
(PyDictionary) other; - for (Object jkey : selfKeys) { - if (!oDict.__contains__(Py.java2py(jkey))) { - return Py.False; - } - } - return Py.True; - } else { - Object oj = other.getJavaProxy(); - if (oj instanceof Map) { - Map oMap = (Map) oj; - return Py.newBoolean(oMap.keySet().containsAll(selfKeys)); - } else { - return null; - } - } + private static class CollectionsProxiesHolder { + static final CollectionProxies proxies = new CollectionProxies(); } + private static Map, PyBuiltinMethod[]> getCollectionProxies() { + return CollectionsProxiesHolder.proxies.proxies; + } + private static Map, PyBuiltinMethod[]> getPostCollectionProxies() { + return CollectionsProxiesHolder.proxies.postProxies; + } /** * Build a map of common Java collection base types (Map, Iterable, etc) that need to be @@ -1124,914 +925,73 @@ * @return A map whose key is the base Java collection types and whose entry is a list of * injected methods. */ - private static Map, PyBuiltinMethod[]> getCollectionProxies() { - if (collectionProxies == null) { - collectionProxies = Generic.map(); - postCollectionProxies = Generic.map(); + private static Map, PyBuiltinMethod[]> buildCollectionProxies() { + final Map, PyBuiltinMethod[]> proxies = new HashMap(); - PyBuiltinMethodNarrow iterableProxy = new PyBuiltinMethodNarrow("__iter__") { - @Override - public PyObject __call__() { - return new IteratorIter(((Iterable)self.getJavaProxy())); - } - }; - collectionProxies.put(Iterable.class, new PyBuiltinMethod[] {iterableProxy}); + PyBuiltinMethodNarrow iterableProxy = new PyBuiltinMethodNarrow("__iter__") { + @Override + public PyObject __call__() { + return new JavaIterator(((Iterable) self.getJavaProxy())); + } + }; + proxies.put(Iterable.class, new PyBuiltinMethod[]{iterableProxy}); - PyBuiltinMethodNarrow lenProxy = new PyBuiltinMethodNarrow("__len__") { - @Override - public PyObject __call__() { - return Py.newInteger(((Collection)self.getJavaProxy()).size()); - } - }; + PyBuiltinMethodNarrow lenProxy 
= new PyBuiltinMethodNarrow("__len__") { + @Override + public PyObject __call__() { + return Py.newInteger(((Collection) self.getJavaProxy()).size()); + } + }; - PyBuiltinMethodNarrow containsProxy = new PyBuiltinMethodNarrow("__contains__", 1) { - @Override - public PyObject __call__(PyObject obj) { - boolean contained = false; - Object proxy = obj.getJavaProxy(); - if (proxy == null) { - for (Object item : (Collection)self.getJavaProxy()) { - if (Py.java2py(item)._eq(obj).__nonzero__()) { - contained = true; - break; - } - } - } else { - Object other = obj.__tojava__(Object.class); - contained = ((Collection)self.getJavaProxy()).contains(other); - - } - return contained ? Py.True : Py.False; - } - }; - collectionProxies.put(Collection.class, new PyBuiltinMethod[] {lenProxy, - containsProxy}); - - PyBuiltinMethodNarrow iteratorProxy = new PyBuiltinMethodNarrow("__iter__") { - @Override - public PyObject __call__() { - return new IteratorIter(((Iterator)self.getJavaProxy())); - } - }; - collectionProxies.put(Iterator.class, new PyBuiltinMethod[] {iteratorProxy}); - - PyBuiltinMethodNarrow enumerationProxy = new PyBuiltinMethodNarrow("__iter__") { - @Override - public PyObject __call__() { - return new EnumerationIter(((Enumeration)self.getJavaProxy())); - } - }; - collectionProxies.put(Enumeration.class, new PyBuiltinMethod[] {enumerationProxy}); - - // Map doesn't extend Collection, so it needs its own version of len, iter and contains - PyBuiltinMethodNarrow mapLenProxy = new MapMethod("__len__", 0) { - @Override - public PyObject __call__() { - return Py.java2py(asMap().size()); - } - }; - PyBuiltinMethodNarrow mapReprProxy = new MapMethod("__repr__", 0) { - @Override - public PyObject __call__() { - StringBuilder repr = new StringBuilder("{"); - for (Map.Entry entry : asMap().entrySet()) { - Object jkey = entry.getKey(); - Object jval = entry.getValue(); - repr.append(jkey.toString()); - repr.append(": "); - repr.append(jval == asMap() ? 
"{...}" : (jval == null ? "None" : jval.toString())); - repr.append(", "); - } - int lastindex = repr.lastIndexOf(", "); - if (lastindex > -1) { - repr.delete(lastindex, lastindex + 2); - } - repr.append("}"); - return new PyString(repr.toString()); - } - }; - PyBuiltinMethodNarrow mapEqProxy = new MapMethod("__eq__", 1) { - @Override - public PyObject __call__(PyObject other) { - return mapEq(self, other); - } - }; - PyBuiltinMethodNarrow mapLeProxy = new MapMethod("__le__", 1) { - @Override - public PyObject __call__(PyObject other) { - return mapLe(self, other); - } - }; - PyBuiltinMethodNarrow mapGeProxy = new MapMethod("__ge__", 1) { - @Override - public PyObject __call__(PyObject other) { - return (mapLe(self, other).__not__()).__or__(mapEq(self, other)); - } - }; - PyBuiltinMethodNarrow mapLtProxy = new MapMethod("__lt__", 1) { - @Override - public PyObject __call__(PyObject other) { - return mapLe(self, other).__and__(mapEq(self, other).__not__()); - } - }; - PyBuiltinMethodNarrow mapGtProxy = new MapMethod("__gt__", 1) { - @Override - public PyObject __call__(PyObject other) { - return mapLe(self, other).__not__(); - } - }; - PyBuiltinMethodNarrow mapIterProxy = new MapMethod("__iter__", 0) { - @Override - public PyObject __call__() { - return new IteratorIter(asMap().keySet()); - } - }; - PyBuiltinMethodNarrow mapContainsProxy = new MapMethod("__contains__", 1) { - @Override - public PyObject __call__(PyObject obj) { - Object other = obj.__tojava__(Object.class); - return asMap().containsKey(other) ? Py.True : Py.False; - } - }; - // "get" needs to override java.util.Map#get() in its subclasses, too, so this needs to be injected last - // (i.e. 
when HashMap is loaded not when it is recursively loading its super-type Map) - PyBuiltinMethodNarrow mapGetProxy = new MapMethod("get", 1, 2) { - @Override - public PyObject __call__(PyObject key) { - return __call__(key, Py.None); - } - @Override - public PyObject __call__(PyObject key, PyObject _default) { - Object jkey = Py.tojava(key, Object.class); - if (asMap().containsKey(jkey)) { - return Py.java2py(asMap().get(jkey)); - } else { - return _default; - } - } - }; - PyBuiltinMethodNarrow mapGetItemProxy = new MapMethod("__getitem__", 1) { - @Override - public PyObject __call__(PyObject key) { - Object jkey = Py.tojava(key, Object.class); - if (asMap().containsKey(jkey)) { - return Py.java2py(asMap().get(jkey)); - } else { - throw Py.KeyError(key); - } - } - }; - PyBuiltinMethodNarrow mapPutProxy = new MapMethod("__setitem__", 2) { - @Override - public PyObject __call__(PyObject key, PyObject value) { - asMap().put(Py.tojava(key, Object.class), - value == Py.None ? Py.None : Py.tojava(value, Object.class)); - return Py.None; - } - }; - PyBuiltinMethodNarrow mapRemoveProxy = new MapMethod("__delitem__", 1) { - @Override - public PyObject __call__(PyObject key) { - Object jkey = Py.tojava(key, Object.class); - if (asMap().remove(jkey) == null) { - throw Py.KeyError(key); - } - return Py.None; - } - }; - PyBuiltinMethodNarrow mapIterItemsProxy = new MapMethod("iteritems", 0) { - @Override - public PyObject __call__() { - final Iterator> entrySetIterator = asMap().entrySet().iterator(); - return new PyIterator() { - @Override - public PyObject __iternext__() { - if (entrySetIterator.hasNext()) { - Map.Entry nextEntry = entrySetIterator.next(); - // yield a Python tuple object (key, value) - return new PyTuple(Py.java2py(nextEntry.getKey()), - Py.java2py(nextEntry.getValue())); - } - return null; - } - }; - } - }; - PyBuiltinMethodNarrow mapHasKeyProxy = new MapMethod("has_key", 1) { - @Override - public PyObject __call__(PyObject key) { - return 
asMap().containsKey(Py.tojava(key, Object.class)) ? Py.True : Py.False; - } - }; - PyBuiltinMethodNarrow mapKeysProxy = new MapMethod("keys", 0) { - @Override - public PyObject __call__() { - PyList keys = new PyList(); - for (Object key : asMap().keySet()) { - keys.add(Py.java2py(key)); - } - return keys; - } - }; - PyBuiltinMethod mapValuesProxy = new MapMethod("values", 0) { - @Override - public PyObject __call__() { - PyList values = new PyList(); - for (Object value : asMap().values()) { - values.add(Py.java2py(value)); - } - return values; - } - }; - PyBuiltinMethodNarrow mapSetDefaultProxy = new MapMethod("setdefault", 1, 2) { - @Override - public PyObject __call__(PyObject key) { - return __call__(key, Py.None); - } - @Override - public PyObject __call__(PyObject key, PyObject _default) { - Object jkey = Py.tojava(key, Object.class); - Object jval = asMap().get(jkey); - if (jval == null) { - asMap().put(jkey, _default == Py.None? Py.None : Py.tojava(_default, Object.class)); - return _default; - } - return Py.java2py(jval); - } - }; - PyBuiltinMethodNarrow mapPopProxy = new MapMethod("pop", 1, 2) { - @Override - public PyObject __call__(PyObject key) { - return __call__(key, null); - } - @Override - public PyObject __call__(PyObject key, PyObject _default) { - Object jkey = Py.tojava(key, Object.class); - if (asMap().containsKey(jkey)) { - PyObject value = Py.java2py(asMap().remove(jkey)); - assert (value != null); - return Py.java2py(value); - } else { - if (_default == null) { - throw Py.KeyError(key); - } - return _default; - } - } - }; - PyBuiltinMethodNarrow mapPopItemProxy = new MapMethod("popitem", 0) { - @Override - public PyObject __call__() { - if (asMap().size() == 0) { - throw Py.KeyError("popitem(): map is empty"); - } - Object key = asMap().keySet().toArray()[0]; - Object val = asMap().remove(key); - return Py.java2py(val); - } - }; - PyBuiltinMethodNarrow mapItemsProxy = new MapMethod("items", 0) { - @Override - public PyObject __call__() { - 
PyList items = new PyList(); - for (Map.Entry entry : asMap().entrySet()) { - items.add(new PyTuple(Py.java2py(entry.getKey()), - Py.java2py(entry.getValue()))); - } - return items; - } - }; - PyBuiltinMethodNarrow mapCopyProxy = new MapMethod("copy", 0) { - @Override - public PyObject __call__() { - Map jmap = asMap(); - Map jclone; - try { - jclone = (Map) jmap.getClass().newInstance(); - } catch (IllegalAccessException e) { - throw Py.JavaError(e); - } catch (InstantiationException e) { - throw Py.JavaError(e); - } - for (Map.Entry entry : jmap.entrySet()) { - jclone.put(entry.getKey(), entry.getValue()); - } - return Py.java2py(jclone); - } - }; - PyBuiltinMethodNarrow mapUpdateProxy = new MapMethod("update", 0, 1) { - private Map jmap; - @Override - public PyObject __call__() { - return Py.None; - } - @Override - public PyObject __call__(PyObject other) { - // `other` is either another dict-like object, or an iterable of key/value pairs (as tuples - // or other iterables of length two) - return __call__(new PyObject[]{other}, new String[]{}); - } - @Override - public PyObject __call__(PyObject[] args, String[] keywords) { - if ((args.length - keywords.length) != 1) { - throw info.unexpectedCall(args.length, false); - } - jmap = asMap(); - PyObject other = args[0]; - // update with entries from `other` (adapted from their equivalent in PyDictionary#update) - Object proxy = other.getJavaProxy(); - if (proxy instanceof Map) { - merge((Map)proxy); - } - else if (other.__findattr__("keys") != null) { - merge(other); - } else { - mergeFromSeq(other); - } - // update with entries from keyword arguments - for (int i = 0; i < keywords.length; i++) { - String jkey = keywords[i]; - PyObject value = args[1+i]; - jmap.put(jkey, Py.tojava(value, Object.class)); - } - return Py.None; - } - private void merge(Map other) { - for (Map.Entry entry : other.entrySet()) { - jmap.put(entry.getKey(), entry.getValue()); - } - } - private void merge(PyObject other) { - if (other 
instanceof PyDictionary) { - jmap.putAll(((PyDictionary) other).getMap()); - } else if (other instanceof PyStringMap) { - mergeFromKeys(other, ((PyStringMap)other).keys()); - } else { - mergeFromKeys(other, other.invoke("keys")); - } - } - private void mergeFromKeys(PyObject other, PyObject keys) { - for (PyObject key : keys.asIterable()) { - jmap.put(Py.tojava(key, Object.class), - Py.tojava(other.__getitem__(key), Object.class)); - } - } - private void mergeFromSeq(PyObject other) { - PyObject pairs = other.__iter__(); - PyObject pair; - - for (int i = 0; (pair = pairs.__iternext__()) != null; i++) { - try { - pair = PySequence.fastSequence(pair, ""); - } catch(PyException pye) { - if (pye.match(Py.TypeError)) { - throw Py.TypeError(String.format("cannot convert dictionary update sequence " - + "element #%d to a sequence", i)); - } - throw pye; - } - int n; - if ((n = pair.__len__()) != 2) { - throw Py.ValueError(String.format("dictionary update sequence element #%d " - + "has length %d; 2 is required", i, n)); - } - jmap.put(Py.tojava(pair.__getitem__(0), Object.class), - Py.tojava(pair.__getitem__(1), Object.class)); - } - } - }; - PyBuiltinClassMethodNarrow mapFromKeysProxy = new MapClassMethod("fromkeys", 1, 2) { - @Override - public PyObject __call__(PyObject keys) { - return __call__(keys, null); - } - @Override - public PyObject __call__(PyObject keys, PyObject _default) { - Object defobj = _default == null ? 
Py.None : Py.tojava(_default, Object.class); - Class theClass = asClass(); - try { - // always injected to java.util.Map, so we know the class object we get from asClass is subtype of java.util.Map - Map theMap = (Map) theClass.newInstance(); - for (PyObject key : keys.asIterable()) { - theMap.put(Py.tojava(key, Object.class), defobj); - } - return Py.java2py(theMap); - } catch (InstantiationException e) { - throw Py.JavaError(e); - } catch (IllegalAccessException e) { - throw Py.JavaError(e); - } - } - }; - collectionProxies.put(Map.class, new PyBuiltinMethod[] { - mapLenProxy, - // map IterProxy can conflict with Iterable.class; - // fix after the fact in handleMroError - mapIterProxy, - mapReprProxy, - mapEqProxy, - mapLeProxy, - mapLtProxy, - mapGeProxy, - mapGtProxy, - mapContainsProxy, - mapGetItemProxy, - //mapGetProxy, - mapPutProxy, - mapRemoveProxy, - mapIterItemsProxy, - mapHasKeyProxy, - mapKeysProxy, - //mapValuesProxy, - mapSetDefaultProxy, - mapPopProxy, - mapPopItemProxy, - mapItemsProxy, - mapCopyProxy, - mapUpdateProxy, - mapFromKeysProxy}); // class method - postCollectionProxies.put(Map.class, new PyBuiltinMethod[] {mapGetProxy, - mapValuesProxy}); - - PyBuiltinMethodNarrow listGetProxy = new ListMethod("__getitem__", 1) { - @Override - public PyObject __call__(PyObject key) { - return new ListIndexDelegate(asList()).checkIdxAndGetItem(key); - } - }; - PyBuiltinMethodNarrow listSetProxy = new ListMethod("__setitem__", 2) { - @Override - public PyObject __call__(PyObject key, PyObject value) { - new ListIndexDelegate(asList()).checkIdxAndSetItem(key, value); - return Py.None; - } - }; - PyBuiltinMethodNarrow listRemoveProxy = new ListMethod("__delitem__", 1) { - @Override - public PyObject __call__(PyObject key) { - new ListIndexDelegate(asList()).checkIdxAndDelItem(key); - return Py.None; - } - }; - PyBuiltinMethodNarrow listEqProxy = new ListMethod("__eq__", 1) { - @Override - public PyObject __call__(PyObject other) { - List jList = asList(); 
- if (other.getType().isSubType(PyList.TYPE)) { - PyList oList = (PyList) other; - if (jList.size() != oList.size()) { - return Py.False; - } - for (int i = 0; i < jList.size(); i++) { - if (!Py.java2py(jList.get(i))._eq(oList.pyget(i)).__nonzero__()) { - return Py.False; - } - } - return Py.True; - } else { - Object oj = other.getJavaProxy(); - if (oj instanceof List) { - List oList = (List) oj; - if (jList.size() != oList.size()) { - return Py.False; - } - for (int i = 0; i < jList.size(); i++) { - if (!Py.java2py(jList.get(i))._eq( - Py.java2py(oList.get(i))).__nonzero__()) { - return Py.False; - } - } - return Py.True; - } else { - return null; + PyBuiltinMethodNarrow containsProxy = new PyBuiltinMethodNarrow("__contains__", 1) { + @Override + public PyObject __call__(PyObject obj) { + boolean contained = false; + Object proxy = obj.getJavaProxy(); + if (proxy == null) { + for (Object item : (Collection) self.getJavaProxy()) { + if (Py.java2py(item)._eq(obj).__nonzero__()) { + contained = true; + break; } } + } else { + Object other = obj.__tojava__(Object.class); + contained = ((Collection) self.getJavaProxy()).contains(other); + } - }; - PyBuiltinMethodNarrow listAppendProxy = new ListMethod("append", 1) { - @Override - public PyObject __call__(PyObject value) { - asList().add(value); - return Py.None; - } - }; - PyBuiltinMethodNarrow listExtendProxy = new ListMethod("extend", 1) { - @Override - public PyObject __call__(PyObject obj) { - List jList = asList(); - List extension = new ArrayList(); + return contained ? 
Py.True : Py.False; + } + }; + proxies.put(Collection.class, new PyBuiltinMethod[]{lenProxy, + containsProxy}); - // Extra step to build the extension list is necessary - // in case of adding to oneself - for (PyObject item : obj.asIterable()) { - extension.add(item); - } - jList.addAll(extension); - return Py.None; - } - }; - PyBuiltinMethodNarrow listInsertProxy = new ListMethod("insert", 2) { - @Override - public PyObject __call__(PyObject index, PyObject object) { - List jlist = asList(); - ListIndexDelegate lid = new ListIndexDelegate(jlist); - int idx = lid.fixBoundIndex(index); - jlist.add(idx, object); - return Py.None; - } - }; - PyBuiltinMethodNarrow listPopProxy = new ListMethod("pop", 0, 1) { - @Override - public PyObject __call__() { - return __call__(Py.newInteger(-1)); - } - @Override - public PyObject __call__(PyObject index) { - List jlist = asList(); - if (jlist.isEmpty()) { - throw Py.IndexError("pop from empty list"); - } - ListIndexDelegate ldel = new ListIndexDelegate(jlist); - PyObject item = ldel.checkIdxAndFindItem(index.asInt()); - if (item == null) { - throw Py.IndexError("pop index out of range"); - } else { - ldel.checkIdxAndDelItem(index); - return item; - } - } - }; - PyBuiltinMethodNarrow listIndexProxy = new ListMethod("index", 1, 3) { - @Override - public PyObject __call__(PyObject object) { - return __call__(object, Py.newInteger(0), Py.newInteger(asList().size())); - } - @Override - public PyObject __call__(PyObject object, PyObject start) { - return __call__(object, start, Py.newInteger(asList().size())); - } - @Override - public PyObject __call__(PyObject object, PyObject start, PyObject end) { - List jlist = asList(); - ListIndexDelegate lid = new ListIndexDelegate(jlist); - int start_index = lid.fixBoundIndex(start); - int end_index = lid.fixBoundIndex(end); - int i = start_index; - try { - for (ListIterator it = jlist.listIterator(start_index); it.hasNext(); i++) { - if (i == end_index) { - break; - } - Object jobj = 
it.next(); - if (Py.java2py(jobj)._eq(object).__nonzero__()) { - return Py.newInteger(i); - } - } - } catch (ConcurrentModificationException e) { - throw Py.ValueError(object.toString() + " is not in list"); - } - throw Py.ValueError(object.toString() + " is not in list"); - } - }; - PyBuiltinMethodNarrow listCountProxy = new ListMethod("count", 1) { - @Override - public PyObject __call__(PyObject object) { - int count = 0; - List jlist = asList(); - for (int i = 0; i < jlist.size(); i++) { - Object jobj = jlist.get(i); - if (Py.java2py(jobj)._eq(object).__nonzero__()) { - ++count; - } - } - return Py.newInteger(count); - } - }; - PyBuiltinMethodNarrow listReverseProxy = new ListMethod("reverse", 0) { - @Override - public PyObject __call__() { - List jlist = asList(); - Collections.reverse(jlist); - return Py.None; - } - }; - PyBuiltinMethodNarrow listRemoveOverrideProxy = new ListMethod("remove", 1) { - @Override - public PyObject __call__(PyObject object) { - List jlist = asList(); - for (int i = 0; i < jlist.size(); i++) { - Object jobj = jlist.get(i); - if (Py.java2py(jobj)._eq(object).__nonzero__()) { - jlist.remove(i); - return Py.None; - } - } - throw Py.ValueError(object.toString() + " is not in list"); - } - }; - PyBuiltinMethodNarrow listRAddProxy = new ListMethod("__radd__", 1) { - @Override - public PyObject __call__(PyObject obj) { - // first, clone the self list - List jList = asList(); - List jClone; - try { - jClone = (List) jList.getClass().newInstance(); - } catch (IllegalAccessException e) { - throw Py.JavaError(e); - } catch (InstantiationException e) { - throw Py.JavaError(e); - } - for (Object entry : jList) { - jClone.add(entry); - } + PyBuiltinMethodNarrow iteratorProxy = new PyBuiltinMethodNarrow("__iter__") { + @Override + public PyObject __call__() { + return new JavaIterator(((Iterator) self.getJavaProxy())); + } + }; + proxies.put(Iterator.class, new PyBuiltinMethod[]{iteratorProxy}); - // then, extend it with elements from the other 
list - // (but, since this is reverse add, we are technically - // pre-pending the clone with elements from the other list) - if (obj instanceof Collection) { - jClone.addAll(0, (Collection) obj); - } else { - int i = 0; - for (PyObject item : obj.asIterable()) { - jClone.add(i, item); - i++; - } - } - - return Py.java2py(jClone); - } - }; - PyBuiltinMethodNarrow listIAddProxy = new ListMethod("__iadd__", 1) { - @Override - public PyObject __call__(PyObject obj) { - List jList = asList(); - if (obj instanceof Collection) { - jList.addAll((Collection) obj); - } else { - for (PyObject item : obj.asIterable()) { - jList.add(item); - } - } - return self; - } - }; - PyBuiltinMethodNarrow listIMulProxy = new ListMethod("__imul__", 1) { - @Override - public PyObject __call__(PyObject obj) { - List jList = asList(); - int mult = obj.asInt(); - - // anything below 0 multiplier, we clear the list - if (mult <= 0) { - jList.clear(); - } else { - try { - if (jList instanceof ArrayList) { - ((ArrayList)jList).ensureCapacity(jList.size() * (mult - 1)); - } - // otherwise, extend it (in-place) x times, where x is int-cast from obj - int originalSize = jList.size(); - for (mult = mult - 1; mult > 0; mult--) { - for (int i = 0; i < originalSize; i++) { - jList.add(jList.get(i)); - } - } - } catch (OutOfMemoryError t) { - throw Py.MemoryError(""); - } - } - return self; - } - }; - PyBuiltinMethodNarrow listSortProxy = new ListMethod("sort", 0, 3) { - @Override - public PyObject __call__() { - list_sort(asList(), Py.None, Py.None, false); - return Py.None; - } - - @Override - public PyObject __call__(PyObject cmp) { - list_sort(asList(), cmp, Py.None, false); - return Py.None; - } - - @Override - public PyObject __call__(PyObject cmp, PyObject key) { - list_sort(asList(), cmp, key, false); - return Py.None; - } - - @Override - public PyObject __call__(PyObject cmp, PyObject key, PyObject reverse) { - list_sort(asList(), cmp, key, reverse.__nonzero__()); - return Py.None; - } - - 
@Override - public PyObject __call__(PyObject[] args, String[] kwds) { - ArgParser ap = new ArgParser("list", args, kwds, new String[]{ - "cmp", "key", "reverse"}, 0); - PyObject cmp = ap.getPyObject(0, Py.None); - PyObject key = ap.getPyObject(1, Py.None); - PyObject reverse = ap.getPyObject(2, Py.False); - list_sort(asList(), cmp, key, reverse.__nonzero__()); - return Py.None; - } - }; - - collectionProxies.put(List.class, new PyBuiltinMethod[] { - listGetProxy, - listSetProxy, - listEqProxy, - listRemoveProxy, - listAppendProxy, - listExtendProxy, - listInsertProxy, - listPopProxy, - listIndexProxy, - listCountProxy, - listReverseProxy, - listRAddProxy, - listIAddProxy, - new ListMulProxyClass("__mul__", 1), - new ListMulProxyClass("__rmul__", 1), - listIMulProxy, - listSortProxy, - }); - postCollectionProxies.put(List.class, new PyBuiltinMethod[]{ - listRemoveOverrideProxy, - }); - } - return collectionProxies; + PyBuiltinMethodNarrow enumerationProxy = new PyBuiltinMethodNarrow("__iter__") { + @Override + public PyObject __call__() { + return new EnumerationIter(((Enumeration) self.getJavaProxy())); + } + }; + proxies.put(Enumeration.class, new PyBuiltinMethod[]{enumerationProxy}); + proxies.put(List.class, JavaProxyList.getProxyMethods()); + proxies.put(Map.class, JavaProxyMap.getProxyMethods()); + proxies.put(Set.class, JavaProxySet.getProxyMethods()); + return Collections.unmodifiableMap(proxies); } - private static Map, PyBuiltinMethod[]> getPostCollectionProxies() { - getCollectionProxies(); - assert (postCollectionProxies != null); - return postCollectionProxies; - } - - - protected static class ListIndexDelegate extends SequenceIndexDelegate { - - private final List list; - - public ListIndexDelegate(List list) { - this.list = list; - } - @Override - public void delItem(int idx) { - list.remove(idx); - } - - @Override - public PyObject getItem(int idx) { - return Py.java2py(list.get(idx)); - } - - @Override - public PyObject getSlice(int start, int 
stop, int step) { - if (step > 0 && stop < start) { - stop = start; - } - int n = PySequence.sliceLength(start, stop, step); - List newList; - try { - newList = list.getClass().newInstance(); - } catch (Exception e) { - throw Py.JavaError(e); - } - int j = 0; - for (int i = start; j < n; i += step) { - newList.add(list.get(i)); - j++; - } - return Py.java2py(newList); - } - - @Override - public String getTypeName() { - return list.getClass().getName(); - } - - @Override - public int len() { - return list.size(); - } - - protected int fixBoundIndex(PyObject index) { - PyInteger length = Py.newInteger(len()); - if (index._lt(Py.Zero).__nonzero__()) { - index = index._add(length); - if (index._lt(Py.Zero).__nonzero__()) { - index = Py.Zero; - } - } else if (index._gt(length).__nonzero__()) { - index = length; - } - int i = index.asIndex(); - assert i >= 0; - return i; - } - - @Override - public void setItem(int idx, PyObject value) { - list.set(idx, value.__tojava__(Object.class)); - } - - @Override - public void setSlice(int start, int stop, int step, PyObject value) { - if (stop < start) { - stop = start; - } - if (value.javaProxy == this.list) { - List xs = Generic.list(); - xs.addAll(this.list); - setsliceList(start, stop, step, xs); - } else if (value instanceof PyList) { - setslicePyList(start, stop, step, (PyList)value); - } else { - Object valueList = value.__tojava__(List.class); - if (valueList != null && valueList != Py.NoConversion) { - setsliceList(start, stop, step, (List)valueList); - } else { - setsliceIterator(start, stop, step, value.asIterable().iterator()); - } - } - } - - final private void setsliceList(int start, int stop, int step, List value) { - if (step == 1) { - list.subList(start, stop).clear(); - list.addAll(start, value); - } else { - int size = list.size(); - Iterator iter = value.listIterator(); - for (int j = start; iter.hasNext(); j += step) { - Object item =iter.next(); - if (j >= size) { - list.add(item); - } else { - list.set(j, 
item); - } - } - } - } - - final private void setsliceIterator(int start, int stop, int step, Iterator iter) { - if (step == 1) { - List insertion = new ArrayList(); - if (iter != null) { - while (iter.hasNext()) { - insertion.add(iter.next().__tojava__(Object.class)); - } - } - list.subList(start, stop).clear(); - list.addAll(start, insertion); - } else { - int size = list.size(); - for (int j = start; iter.hasNext(); j += step) { - Object item = iter.next().__tojava__(Object.class); - if (j >= size) { - list.add(item); - } else { - list.set(j, item); - } - } - } - } - - final private void setslicePyList(int start, int stop, int step, PyList value) { - if (step == 1) { - list.subList(start, stop).clear(); - int n = value.getList().size(); - for (int i=0, j=start; i iter = value.getList().listIterator(); - for (int j = start; iter.hasNext(); j += step) { - Object item = iter.next().__tojava__(Object.class); - if (j >= size) { - list.add(item); - } else { - list.set(j, item); - } - } - } - } - - - @Override - public void delItems(int start, int stop) { - int n = stop - start; - while (n-- > 0) { - delItem(start); - } - } + private static Map, PyBuiltinMethod[]> buildPostCollectionProxies() { + final Map, PyBuiltinMethod[]> postProxies = new HashMap(); + postProxies.put(List.class, JavaProxyList.getPostProxyMethods()); + postProxies.put(Map.class, JavaProxyMap.getPostProxyMethods()); + postProxies.put(Set.class, JavaProxySet.getPostProxyMethods()); + return Collections.unmodifiableMap(postProxies); } } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Mon Dec 29 08:00:07 2014 From: jython-checkins at python.org (jim.baker) Date: Mon, 29 Dec 2014 07:00:07 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_select=2Eselect_and_related?= =?utf-8?q?_socket=2Econnect=5Fex_fixes?= Message-ID: <20141229070005.125143.85430@psf.io> https://hg.python.org/jython/rev/3c4adde9c083 changeset: 7475:3c4adde9c083 user: Jim Baker date: Mon Dec 29 
00:00:00 2014 -0700 summary: select.select and related socket.connect_ex fixes Fixed child socket handling such that if the parent socket is blocking, the child socket will now immediately complete its setup. Fixed socket.connect_ex so that it depends on the underlying connect future, as well as catching any socket.error exceptions and returning the corresponding errno. This change enables test_select_new. Fixes http://bugs.jython.org/issue2242 files: Lib/_socket.py | 24 ++++++++++++++++-------- Lib/test/regrtest.py | 1 - Lib/test/test_select_new.py | 18 +++++++++--------- Lib/test/test_socket.py | 15 ++++++++++++++- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/Lib/_socket.py b/Lib/_socket.py --- a/Lib/_socket.py +++ b/Lib/_socket.py @@ -615,6 +615,8 @@ # thread pool child_channel.closeFuture().addListener(unlatch_child) + if self.parent_socket.timeout is None: + child._ensure_post_connect() child._wait_on_latch() log.debug("Socket initChannel completed waiting on latch", extra={"sock": child}) @@ -842,8 +844,8 @@ log.debug("Connect to %s", addr, extra={"sock": self}) self.channel = bootstrap.channel() - connect_future = self.channel.connect(addr) - self._handle_channel_future(connect_future, "connect") + self.connect_future = self.channel.connect(addr) + self._handle_channel_future(self.connect_future, "connect") self.bind_timestamp = time.time() def _post_connect(self): @@ -871,12 +873,17 @@ log.debug("Completed connection to %s", addr, extra={"sock": self}) def connect_ex(self, addr): - self.connect(addr) - if self.timeout is None: + if not self.connected: + try: + self.connect(addr) + except error as e: + return e.errno + if not self.connect_future.isDone(): + return errno.EINPROGRESS + elif self.connect_future.isSuccess(): return errno.EISCONN else: - return errno.EINPROGRESS - + return errno.ENOTCONN # SERVER METHODS # Calling listen means this is a server socket @@ -1033,11 +1040,11 @@ pass # already removed, can safely ignore (presumably) 
if how & SHUT_WR: self._can_write = False - + def _readable(self): if self.socket_type == CLIENT_SOCKET or self.socket_type == DATAGRAM_SOCKET: log.debug("Incoming head=%s queue=%s", self.incoming_head, self.incoming, extra={"sock": self}) - return ( + return bool( (self.incoming_head is not None and self.incoming_head.readableBytes()) or self.incoming.peek()) elif self.socket_type == SERVER_SOCKET: @@ -1338,6 +1345,7 @@ self.active = AtomicBoolean() self.active_latch = CountDownLatch(1) self.accepted = False + self.timeout = parent_socket.timeout def _ensure_post_connect(self): do_post_connect = not self.active.getAndSet(True) diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py --- a/Lib/test/regrtest.py +++ b/Lib/test/regrtest.py @@ -1314,7 +1314,6 @@ test_peepholer test_pyclbr test_pyexpat - test_select_new test_stringprep test_threadsignals test_transformer diff --git a/Lib/test/test_select_new.py b/Lib/test/test_select_new.py --- a/Lib/test/test_select_new.py +++ b/Lib/test/test_select_new.py @@ -16,18 +16,15 @@ DATA_CHUNK = "." * DATA_CHUNK_SIZE # -# The timing of these tests depends on the how the unerlying OS socket library +# The timing of these tests depends on the how the underlying OS socket library # handles buffering. 
These values may need tweaking for different platforms # # The fundamental problem is that there is no reliable way to fill a socket with bytes -# +# To address this for running on Netty, we arbitrarily send 10000 bytes -if test_support.is_jython: - SELECT_TIMEOUT = 0 -else: - # zero select timeout fails these tests on cpython (on windows 2003 anyway) - SELECT_TIMEOUT = 0.001 - +# zero select timeout fails these tests on cpython (on windows 2003 anyway); +# on Jython with Netty it will result in flaky test runs +SELECT_TIMEOUT = 0.001 READ_TIMEOUT = 5 class AsynchronousServer: @@ -86,6 +83,9 @@ if self.select_writable(): bytes_sent = self.socket.send(DATA_CHUNK) total_bytes += bytes_sent + if test_support.is_jython and total_bytes > 10000: + # Netty will buffer indefinitely, so just pick an arbitrary cutoff + return total_bytes else: return total_bytes except socket.error, se: @@ -149,7 +149,7 @@ def start_connect(self): result = self.socket.connect_ex(SERVER_ADDRESS) if result == errno.EISCONN: - self.connected = 1 + self.connected = True else: assert result == errno.EINPROGRESS diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1125,6 +1125,20 @@ self.serv_conn.send(MSG) self.serv_conn.send('and ' + MSG) + def testSelect(self): + # http://bugs.jython.org/issue2242 + self.assertIs(self.cli_conn.gettimeout(), None, "Server socket is not blocking") + start_time = time.time() + r, w, x = select.select([self.cli_conn], [], [], 10) + if (time.time() - start_time) > 1: + self.fail("Child socket was not immediately available for read when set to blocking") + self.assertEqual(r[0], self.cli_conn) + self.assertEqual(self.cli_conn.recv(1024), MSG) + + def _testSelect(self): + self.serv_conn.send(MSG) + + class UDPBindTest(unittest.TestCase): HOST = HOST @@ -1396,7 +1410,6 @@ def _testRecvData(self): self.cli.connect((self.HOST, self.PORT)) self.cli.send(MSG) - #time.sleep(0.5) def 
testRecvNoData(self): # Testing non-blocking recv -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Tue Dec 30 07:44:34 2014 From: jython-checkins at python.org (jim.baker) Date: Tue, 30 Dec 2014 06:44:34 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fully_proxy_java=2Eutil=2ES?= =?utf-8?q?et_objects_as_Python_sets?= Message-ID: <20141230064432.22216.48437@psf.io> https://hg.python.org/jython/rev/07959ff1d8e9 changeset: 7476:07959ff1d8e9 user: Jim Baker date: Mon Dec 29 23:44:05 2014 -0700 summary: Fully proxy java.util.Set objects as Python sets The remove method now throws a KeyError if the value to be removed is not found in the set. Before this would follow Java semantics and return a boolean if successful or not. No changes were necessary in the standard library, perhaps indicating that this is a rare assumption in user code. Note that sets implementing SortedSet do not require hashing, but instead assume that there is a total ordering function on the set membership. The current test suite for testing set functionality, test_set, as subclassed in test_set_jy makes the perfectly correct assumption for CPython that hashability is required. This should be revisited with respect to testing TreeSet and ConcurrentSkipListSet. 
Fixes http://bugs.jython.org/issue2241 files: Lib/test/test_set_jy.py | 50 +- src/org/python/core/BaseSet.java | 4 + src/org/python/core/JavaProxySet.java | 550 +++++++++++++- 3 files changed, 583 insertions(+), 21 deletions(-) diff --git a/Lib/test/test_set_jy.py b/Lib/test/test_set_jy.py --- a/Lib/test/test_set_jy.py +++ b/Lib/test/test_set_jy.py @@ -1,13 +1,14 @@ import unittest -from test import test_support +from test import test_support, test_set +import pickle import threading -if test_support.is_jython: - from java.io import (ByteArrayInputStream, ByteArrayOutputStream, - ObjectInputStream, ObjectOutputStream) - from java.util import Random - from javatests import PySetInJavaTest +from java.io import (ByteArrayInputStream, ByteArrayOutputStream, + ObjectInputStream, ObjectOutputStream) +from java.util import Random, HashSet, LinkedHashSet +from javatests import PySetInJavaTest + class SetTestCase(unittest.TestCase): @@ -81,10 +82,41 @@ unserializer = ObjectInputStream(input) self.assertEqual(s, unserializer.readObject()) + +class TestJavaSet(test_set.TestSet): + thetype = HashSet + + def test_init(self): + # Instances of Java types cannot be re-initialized + pass + + def test_cyclical_repr(self): + pass + + def test_cyclical_print(self): + pass + + def test_pickling(self): + for i in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(self.s, i) + dup = pickle.loads(p) + self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup)) + + +class TestJavaHashSet(TestJavaSet): + thetype = HashSet + +class TestJavaLinkedHashSet(TestJavaSet): + thetype = LinkedHashSet + + def test_main(): - tests = [SetTestCase] - if test_support.is_jython: - tests.append(SetInJavaTestCase) + tests = [ + SetTestCase, + SetInJavaTestCase, + TestJavaHashSet, + TestJavaLinkedHashSet, + ] test_support.run_unittest(*tests) diff --git a/src/org/python/core/BaseSet.java b/src/org/python/core/BaseSet.java --- a/src/org/python/core/BaseSet.java +++ b/src/org/python/core/BaseSet.java @@ 
-20,6 +20,10 @@ _set = set; } + public Set getSet() { + return _set; + } + protected void _update(PyObject data) { _update(_set, data); } diff --git a/src/org/python/core/JavaProxySet.java b/src/org/python/core/JavaProxySet.java --- a/src/org/python/core/JavaProxySet.java +++ b/src/org/python/core/JavaProxySet.java @@ -1,5 +1,12 @@ package org.python.core; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.NavigableSet; +import java.util.NoSuchElementException; import java.util.Set; /** Proxy objects implementing java.util.Set */ @@ -7,6 +14,7 @@ class JavaProxySet { private static class SetMethod extends PyBuiltinMethodNarrow { + protected SetMethod(String name, int numArgs) { super(name, numArgs); } @@ -15,34 +23,552 @@ super(name, minArgs, maxArgs); } - protected Set asSet(){ - return (Set)self.getJavaProxy(); + @SuppressWarnings("unchecked") + protected Set asSet() { + return (Set) self.getJavaProxy(); } - protected Set newSet() { - try { - return (Set) asSet().getClass().newInstance(); - } catch (IllegalAccessException e) { - throw Py.JavaError(e); - } catch (InstantiationException e) { - throw Py.JavaError(e); + // Unlike list and dict, set maintains the derived type for the set + // so we replicate this behavior + protected PyObject makePySet(Set newSet) { + PyObject newPySet = self.getType().__call__(); + @SuppressWarnings("unchecked") + Set jSet = ((Set) newPySet.getJavaProxy()); + jSet.addAll(newSet); + return newPySet; + } + + public boolean isEqual(PyObject other) { + Set selfSet = asSet(); + Object oj = other.getJavaProxy(); + if (oj != null && oj instanceof Set) { + @SuppressWarnings("unchecked") + Set otherSet = (Set) oj; + if (selfSet.size() != otherSet.size()) { + return false; + } + return selfSet.containsAll(otherSet); + } else if (isPySet(other)) { + Set otherPySet = ((BaseSet) other).getSet(); + if (selfSet.size() != 
otherPySet.size()) { + return false; + } + for (PyObject pyobj : otherPySet) { + if (!selfSet.contains(pyobj.__tojava__(Object.class))) { + return false; + } + } + return true; } + return false; + } + + public boolean isSuperset(PyObject other) { + Set selfSet = asSet(); + Object oj = other.getJavaProxy(); + if (oj != null && oj instanceof Set) { + Set otherSet = (Set) oj; + return selfSet.containsAll(otherSet); + } else if (isPySet(other)) { + Set otherPySet = ((BaseSet) other).getSet(); + for (PyObject pyobj : otherPySet) { + if (!selfSet.contains(pyobj.__tojava__(Object.class))) { + return false; + } + } + return true; + } + return false; + } + + public boolean isSubset(PyObject other) { + Set selfSet = asSet(); + Object oj = other.getJavaProxy(); + if (oj != null && oj instanceof Set) { + @SuppressWarnings("unchecked") + Set otherSet = (Set) oj; + return otherSet.containsAll(selfSet); + } else if (isPySet(other)) { + Set otherPySet = ((BaseSet) other).getSet(); + for (Object obj : selfSet) { + if (!otherPySet.contains(Py.java2py(obj))) { + return false; + } + } + return true; + } + return false; + } + + protected Set difference(Collection other) { + Set selfSet = asSet(); + Set diff = new HashSet<>(selfSet); + diff.removeAll(other); + return diff; + } + protected void differenceUpdate(Collection other) { + asSet().removeAll(other); + } + + protected Set intersect(Collection[] others) { + Set selfSet = asSet(); + Set intersection = new HashSet<>(selfSet); + for (Collection other : others) { + intersection.retainAll(other); + } + return intersection; + } + protected void intersectUpdate(Collection[] others) { + Set selfSet = asSet(); + for (Collection other : others) { + selfSet.retainAll(other); + } + } + + protected Set union(Collection other) { + Set selfSet = asSet(); + Set u = new HashSet<>(selfSet); + u.addAll(other); + return u; + } + protected void update(Collection other) { + asSet().addAll(other); + } + + protected Set symDiff(Collection other) { + Set 
selfSet = asSet(); + Set symDiff = new HashSet<>(selfSet); + symDiff.addAll(other); + Set intersection = new HashSet<>(selfSet); + intersection.retainAll(other); + symDiff.removeAll(intersection); + return symDiff; + } + protected void symDiffUpdate(Collection other) { + Set selfSet = asSet(); + Set intersection = new HashSet<>(selfSet); + intersection.retainAll(other); + selfSet.addAll(other); + selfSet.removeAll(intersection); } } - private static final PyBuiltinMethodNarrow setIsDisjointProxy = new SetMethod("isdisjoint", 1) { + private static class SetMethodVarargs extends SetMethod { + protected SetMethodVarargs(String name) { + super(name, 0, -1); + } + + public PyObject __call__() { + return __call__(Py.EmptyObjects); + } + + public PyObject __call__(PyObject obj) { + return __call__(new PyObject[]{obj}); + } + + public PyObject __call__(PyObject obj1, PyObject obj2) { + return __call__(new PyObject[]{obj1, obj2}); + } + + public PyObject __call__(PyObject obj1, PyObject obj2, PyObject obj3) { + return __call__(new PyObject[]{obj1, obj2, obj3}); + } + + public PyObject __call__(PyObject obj1, PyObject obj2, PyObject obj3, PyObject obj4) { + return __call__(new PyObject[]{obj1, obj2, obj3, obj4}); + } + } + + private static boolean isPySet(PyObject obj) { + PyType type = obj.getType(); + return type.isSubType(PySet.TYPE) || type.isSubType(PyFrozenSet.TYPE); + } + + private static Collection getJavaSet(PyObject self, String op, PyObject obj) { + Collection items; + if (isPySet(obj)) { + Set otherPySet = ((BaseSet)obj).getSet(); + items = new ArrayList<>(otherPySet.size()); + for (PyObject pyobj : otherPySet) { + items.add(pyobj.__tojava__(Object.class)); + } + } else { + Object oj = obj.getJavaProxy(); + if (oj instanceof Set) { + @SuppressWarnings("unchecked") + Set jSet = (Set) oj; + items = jSet; + } else { + throw Py.TypeError(String.format( + "unsupported operand type(s) for %s: '%.200s' and '%.200s'", + op, self.getType().fastGetName(), 
obj.getType().fastGetName())); + } + } + return items; + } + + private static Collection getJavaCollection(PyObject obj) { + Collection items; + Object oj = obj.getJavaProxy(); + if (oj != null) { + if (oj instanceof Collection) { + @SuppressWarnings("unchecked") + Collection jCollection = (Collection) oj; + items = jCollection; + } else if (oj instanceof Iterable) { + items = new HashSet<>(); + for (Object item: (Iterable) oj) { + items.add(item); + } + } else { + throw Py.TypeError(String.format("unsupported operand type(s): '%.200s'", + obj.getType().fastGetName())); + } + } else { + // This step verifies objects are hashable + items = new HashSet<>(); + for (PyObject pyobj : obj.asIterable()) { + items.add(pyobj.__tojava__(Object.class)); + } + } + return items; + } + + private static Collection[] getJavaCollections(PyObject[] objs) { + Collection[] collections = new Collection[objs.length]; + int i = 0; + for (PyObject obj : objs) { + collections[i++] = getJavaCollection(obj); + } + return collections; + } + + private static Collection getCombinedJavaCollections(PyObject[] objs) { + if (objs.length == 0) { + return Collections.emptyList(); + } + if (objs.length == 1) { + return getJavaCollection(objs[0]); + } + Set items = new HashSet<>(); + for (PyObject obj : objs) { + Object oj = obj.getJavaProxy(); + if (oj != null) { + if (oj instanceof Iterable) { + for (Object item : (Iterable) oj) { + items.add(item); + } + } else { + throw Py.TypeError(String.format("unsupported operand type(s): '%.200s'", + obj.getType().fastGetName())); + } + } else { + for (PyObject pyobj : obj.asIterable()) { + items.add(pyobj.__tojava__(Object.class)); + } + } + } + return items; + } + + private static final SetMethod cmpProxy = new SetMethod("__cmp__", 1) { + @Override + public PyObject __call__(PyObject value) { + throw Py.TypeError("cannot compare sets using cmp()"); + } + }; + private static final SetMethod eqProxy = new SetMethod("__eq__", 1) { @Override public PyObject 
__call__(PyObject other) { + return Py.newBoolean(isEqual(other)); + } + }; + private static final SetMethod ltProxy = new SetMethod("__lt__", 1) { + @Override + public PyObject __call__(PyObject other) { + return Py.newBoolean(!isEqual(other) && isSubset(other)); + } + }; + private static class IsSubsetMethod extends SetMethod { + // __le__, issubset + + protected IsSubsetMethod(String name) { + super(name, 1); + } + + @Override + public PyObject __call__(PyObject other) { + return Py.newBoolean(isSubset(other)); + } + } + private static class IsSupersetMethod extends SetMethod { + // __ge__, issuperset + + protected IsSupersetMethod(String name) { + super(name, 1); + } + + @Override + public PyObject __call__(PyObject other) { + return Py.newBoolean(isSuperset(other)); + } + } + private static final SetMethod gtProxy = new SetMethod("__gt__", 1) { + @Override + public PyObject __call__(PyObject other) { + return Py.newBoolean(!isEqual(other) && isSuperset(other)); + } + }; + private static final SetMethod isDisjointProxy = new SetMethod("isdisjoint", 1) { + @Override + public PyObject __call__(PyObject other) { + return Py.newBoolean(intersect(new Collection[]{getJavaCollection(other)}).size() == 0); + } + }; + + private static final SetMethod differenceProxy = new SetMethodVarargs("difference") { + @Override + public PyObject __call__(PyObject[] others) { + return makePySet(difference(getCombinedJavaCollections(others))); + } + }; + private static final SetMethod differenceUpdateProxy = new SetMethodVarargs("difference_update") { + @Override + public PyObject __call__(PyObject[] others) { + differenceUpdate(getCombinedJavaCollections(others)); + return Py.None; + } + }; + private static final SetMethod subProxy = new SetMethod("__sub__", 1) { + @Override + public PyObject __call__(PyObject other) { + return makePySet(difference(getJavaSet(self, "-", other))); + } + }; + private static final SetMethod isubProxy = new SetMethod("__isub__", 1) { + @Override + 
public PyObject __call__(PyObject other) { + differenceUpdate(getJavaSet(self, "-=", other)); + return self; + } + }; + + private static final SetMethod intersectionProxy = new SetMethodVarargs("intersection") { + @Override + public PyObject __call__(PyObject[] others) { + return makePySet(intersect(getJavaCollections(others))); + } + }; + private static final SetMethod intersectionUpdateProxy = new SetMethodVarargs("intersection_update") { + @Override + public PyObject __call__(PyObject[] others) { + intersectUpdate(getJavaCollections(others)); + return Py.None; + } + }; + private static final SetMethod andProxy = new SetMethod("__and__", 1) { + @Override + public PyObject __call__(PyObject other) { + return makePySet(intersect(new Collection[]{getJavaSet(self, "&", other)})); + } + }; + private static final SetMethod iandProxy = new SetMethod("__iand__", 1) { + @Override + public PyObject __call__(PyObject other) { + intersectUpdate(new Collection[]{getJavaSet(self, "&=", other)}); + return self; + } + }; + + private static final SetMethod symDiffProxy = new SetMethod("symmetric_difference", 1) { + @Override + public PyObject __call__(PyObject other) { + return makePySet(symDiff(getJavaCollection(other))); + } + }; + private static final SetMethod symDiffUpdateProxy = new SetMethod("symmetric_difference_update", 1) { + @Override + public PyObject __call__(PyObject other) { + symDiffUpdate(getJavaCollection(other)); + return Py.None; + } + }; + private static final SetMethod xorProxy = new SetMethod("__xor__", 1) { + @Override + public PyObject __call__(PyObject other) { + return makePySet(symDiff(getJavaSet(self, "^", other))); + } + }; + private static final SetMethod ixorProxy = new SetMethod("__ixor__", 1) { + @Override + public PyObject __call__(PyObject other) { + symDiffUpdate(getJavaSet(self, "^=", other)); + return self; + } + }; + + private static final SetMethod unionProxy = new SetMethodVarargs("union") { + @Override + public PyObject 
__call__(PyObject[] others) { + return makePySet(union(getCombinedJavaCollections(others))); + } + }; + private static final SetMethod updateProxy = new SetMethodVarargs("update") { + @Override + public PyObject __call__(PyObject[] others) { + update(getCombinedJavaCollections(others)); + return Py.None; + } + }; + private static final SetMethod orProxy = new SetMethod("__or__", 1) { + @Override + public PyObject __call__(PyObject other) { + return makePySet(union(getJavaSet(self, "|", other))); + } + }; + private static final SetMethod iorProxy = new SetMethod("__ior__", 1) { + @Override + public PyObject __call__(PyObject other) { + update(getJavaSet(self, "|=", other)); + return self; + } + }; + + private static class CopyMethod extends SetMethod { + protected CopyMethod(String name) { + super(name, 0); + } + @Override + public PyObject __call__() { + return makePySet(asSet()); + } + } + + private static final SetMethod deepcopyOverrideProxy = new SetMethod("__deepcopy__", 1) { + @Override + public PyObject __call__(PyObject memo) { + Set newSet = new HashSet<>(); + for (Object obj : asSet()) { + PyObject pyobj = Py.java2py(obj); + PyObject newobj = pyobj.invoke("__deepcopy__", memo); + newSet.add(newobj.__tojava__(Object.class)); + } + return makePySet(newSet); + } + }; + + private static final SetMethod reduceProxy = new SetMethod("__reduce__", 0) { + @Override + public PyObject __call__() { + PyObject args = new PyTuple(new PyList(new JavaIterator(asSet()))); + PyObject dict = __findattr__("__dict__"); + if (dict == null) { + dict = Py.None; + } + return new PyTuple(self.getType(), args, dict); + } + }; + + private static final SetMethod containsProxy = new SetMethod("__contains__", 1) { + @Override + public PyObject __call__(PyObject value) { + return Py.newBoolean(asSet().contains(value.__tojava__(Object.class))); + } + }; + private static final SetMethod hashProxy = new SetMethod("__hash__", 0) { + // in general, we don't know if this is really true or not 
+ @Override + public PyObject __call__(PyObject value) { + throw Py.TypeError(String.format("unhashable type: '%.200s'", self.getType().fastGetName())); + } + }; + + private static final SetMethod discardProxy = new SetMethod("discard", 1) { + @Override + public PyObject __call__(PyObject value) { + asSet().remove(value.__tojava__(Object.class)); + return Py.None; + } + }; + private static final SetMethod popProxy = new SetMethod("pop", 0) { + @Override + public PyObject __call__() { + Set selfSet = asSet(); + Iterator it; + if (selfSet instanceof NavigableSet) { + it = ((NavigableSet) selfSet).descendingIterator(); + } else { + it = selfSet.iterator(); + } + try { + PyObject value = Py.java2py(it.next()); + it.remove(); + return value; + } catch (NoSuchElementException ex) { + throw Py.KeyError("pop from an empty set"); + } + } + }; + private static final SetMethod removeOverrideProxy = new SetMethod("remove", 1) { + @Override + public PyObject __call__(PyObject value) { + boolean removed = asSet().remove(value.__tojava__(Object.class)); + if (!removed) { + throw Py.KeyError(value); + } return Py.None; } }; static PyBuiltinMethod[] getProxyMethods() { - return new PyBuiltinMethod[]{}; + return new PyBuiltinMethod[]{ + cmpProxy, + eqProxy, + ltProxy, + new IsSubsetMethod("__le__"), + new IsSubsetMethod("issubset"), + new IsSupersetMethod("__ge__"), + new IsSupersetMethod("issuperset"), + gtProxy, + isDisjointProxy, + + differenceProxy, + differenceUpdateProxy, + subProxy, + isubProxy, + + intersectionProxy, + intersectionUpdateProxy, + andProxy, + iandProxy, + + symDiffProxy, + symDiffUpdateProxy, + xorProxy, + ixorProxy, + + unionProxy, + updateProxy, + orProxy, + iorProxy, + + new CopyMethod("copy"), + new CopyMethod("__copy__"), + reduceProxy, + + containsProxy, + hashProxy, + + discardProxy, + popProxy + }; } static PyBuiltinMethod[] getPostProxyMethods() { - return new PyBuiltinMethod[]{}; + return new PyBuiltinMethod[]{ + deepcopyOverrideProxy, + 
removeOverrideProxy + }; } } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:07 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:07 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Rationalise_certain_constan?= =?utf-8?q?ts_in_PyFloat=2E?= Message-ID: <20141231014104.125163.2951@psf.io> https://hg.python.org/jython/rev/fce295e64d78 changeset: 7481:fce295e64d78 user: Jeff Allen date: Thu Dec 18 23:12:45 2014 +0000 summary: Rationalise certain constants in PyFloat. files: src/org/python/core/PyFloat.java | 22 ++++++++++++------- src/org/python/core/PyLong.java | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/org/python/core/PyFloat.java b/src/org/python/core/PyFloat.java --- a/src/org/python/core/PyFloat.java +++ b/src/org/python/core/PyFloat.java @@ -29,6 +29,12 @@ static final Spec SPEC_REPR = InternalFormat.fromText(" >r"); /** Format specification used by str(). */ static final Spec SPEC_STR = Spec.NUMERIC; + /** Constant float(0). */ + static final PyFloat ZERO = new PyFloat(0.0); + /** Constant float(1). */ + static final PyFloat ONE = new PyFloat(1.0); + /** Constant float("nan"). 
*/ + static final PyFloat NAN = new PyFloat(Double.NaN); private final double value; @@ -56,7 +62,7 @@ PyObject x = ap.getPyObject(0, null); if (x == null) { if (new_.for_type == subtype) { - return new PyFloat(0.0); + return ZERO; } else { return new PyFloatDerived(subtype, 0.0); } @@ -90,7 +96,7 @@ @ExposedGet(name = "imag", doc = BuiltinDocs.float_imag_doc) public PyObject getImag() { - return Py.newFloat(0.0); + return ZERO; } @ExposedClassMethod(doc = BuiltinDocs.float_fromhex_doc) @@ -108,7 +114,7 @@ if (value.length() == 0) { throw Py.ValueError(message); } else if (value.equals("nan") || value.equals("-nan") || value.equals("+nan")) { - return new PyFloat(Double.NaN); + return NAN; } else if (value.equals("inf") || value.equals("infinity") || value.equals("+inf") || value.equals("+infinity")) { return new PyFloat(Double.POSITIVE_INFINITY); @@ -760,18 +766,18 @@ */ if (w == 0) { // v**0 is 1, even 0**0 - return new PyFloat(1.0); + return ONE; } else if (Double.isNaN(v)) { // nan**w = nan, unless w == 0 - return new PyFloat(Double.NaN); + return NAN; } else if (Double.isNaN(w)) { // v**nan = nan, unless v == 1; 1**nan = 1 if (v == 1.0) { - return new PyFloat(1.0); + return ONE; } else { - return new PyFloat(Double.NaN); + return NAN; } } else if (Double.isInfinite(w)) { @@ -780,7 +786,7 @@ * Python they are all 1. 
*/ if (v == 1.0 || v == -1.0) { - return new PyFloat(1.0); + return ONE; } } else if (v == 0.0) { diff --git a/src/org/python/core/PyLong.java b/src/org/python/core/PyLong.java --- a/src/org/python/core/PyLong.java +++ b/src/org/python/core/PyLong.java @@ -545,7 +545,7 @@ if (aexp > Integer.MAX_VALUE / 8) { throw Py.OverflowError("long/long too large for a float"); } else if (aexp < -(Integer.MAX_VALUE / 8)) { - return new PyFloat(0.0); + return PyFloat.ZERO; } ad = ad * Math.pow(2.0, aexp * 8); -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:07 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:07 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Distinguish_constants_0j_an?= =?utf-8?q?d_-0j_when_generating_code=2E?= Message-ID: <20141231014105.125161.9713@psf.io> https://hg.python.org/jython/rev/43b491fcfe98 changeset: 7484:43b491fcfe98 user: Jeff Allen date: Sat Dec 20 23:51:33 2014 +0000 summary: Distinguish constants 0j and -0j when generating code. Fixes a failure in test_complex: -0.0 and 0.0 are not the same constant. 
files: Lib/test/test_complex.py | 2 -- src/org/python/compiler/Module.java | 16 ++++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -501,8 +501,6 @@ self.assertEqual(complex(INF, 0).__getnewargs__(), (INF, 0.0)) if float.__getformat__("double").startswith("IEEE"): - @unittest.skipIf(test_support.is_jython, - "FIXME: not working in Jython") def test_plus_minus_0j(self): # test that -0j and 0j literals are not identified z1, z2 = 0j, -0j diff --git a/src/org/python/compiler/Module.java b/src/org/python/compiler/Module.java --- a/src/org/python/compiler/Module.java +++ b/src/org/python/compiler/Module.java @@ -101,14 +101,12 @@ @Override public boolean equals(Object o) { if (o instanceof PyFloatConstant) { - double oVal = ((PyFloatConstant)o).value; - if (ZERO == value) { - // math.copysign() needs to distinguish signs of zeroes - return oVal == value && Double.toString(oVal).equals(Double.toString(value)); - } - return oVal == value; + // Ensure hashtable works things like for -0.0 and NaN (see java.lang.Double.equals). + PyFloatConstant pyco = (PyFloatConstant)o; + return Double.doubleToLongBits(pyco.value) == Double.doubleToLongBits(value); + } else { + return false; } - return false; } } @@ -138,7 +136,9 @@ @Override public boolean equals(Object o) { if (o instanceof PyComplexConstant) { - return ((PyComplexConstant)o).value == value; + // Ensure hashtable works things like for -0.0 and NaN (see java.lang.Double.equals). 
+ PyComplexConstant pyco = (PyComplexConstant)o; + return Double.doubleToLongBits(pyco.value) == Double.doubleToLongBits(value); } else { return false; } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:07 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:07 +0000 Subject: [Jython-checkins] =?utf-8?b?anl0aG9uOiBGaXggZmxvYXQuX19tb2RfXyB0?= =?utf-8?q?o_conform_to_spec=2E?= Message-ID: <20141231014103.102375.54017@psf.io> https://hg.python.org/jython/rev/2472ea52bc76 changeset: 7479:2472ea52bc76 user: Jeff Allen date: Wed Dec 17 21:24:56 2014 +0000 summary: Fix float.__mod__ to conform to spec. Fixes test failures in test_float by re-implementing using Java % and tweaks. files: Lib/test/test_float.py | 9 ++--- src/org/python/core/PyFloat.java | 32 +++++++++++++++----- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py --- a/Lib/test/test_float.py +++ b/Lib/test/test_float.py @@ -217,10 +217,9 @@ # In particular, check signs of zeros. mod = operator.mod - #FIXME: Jython fails some mod edge cases. 
- #self.assertEqualAndEqualSign(mod(-1.0, 1.0), 0.0) + self.assertEqualAndEqualSign(mod(-1.0, 1.0), 0.0) self.assertEqualAndEqualSign(mod(-1e-100, 1.0), 1.0) - #self.assertEqualAndEqualSign(mod(-0.0, 1.0), 0.0) + self.assertEqualAndEqualSign(mod(-0.0, 1.0), 0.0) self.assertEqualAndEqualSign(mod(0.0, 1.0), 0.0) self.assertEqualAndEqualSign(mod(1e-100, 1.0), 1e-100) self.assertEqualAndEqualSign(mod(1.0, 1.0), 0.0) @@ -228,9 +227,9 @@ self.assertEqualAndEqualSign(mod(-1.0, -1.0), -0.0) self.assertEqualAndEqualSign(mod(-1e-100, -1.0), -1e-100) self.assertEqualAndEqualSign(mod(-0.0, -1.0), -0.0) - #self.assertEqualAndEqualSign(mod(0.0, -1.0), -0.0) + self.assertEqualAndEqualSign(mod(0.0, -1.0), -0.0) self.assertEqualAndEqualSign(mod(1e-100, -1.0), -1.0) - #self.assertEqualAndEqualSign(mod(1.0, -1.0), -0.0) + self.assertEqualAndEqualSign(mod(1.0, -1.0), -0.0) @requires_IEEE_754 def test_float_pow(self): diff --git a/src/org/python/core/PyFloat.java b/src/org/python/core/PyFloat.java --- a/src/org/python/core/PyFloat.java +++ b/src/org/python/core/PyFloat.java @@ -625,15 +625,30 @@ return new PyFloat(leftv / getValue()); } + /** + * Python % operator: y = n*x + z. The modulo operator always yields a result with the same sign + * as its second operand (or zero). (Compare java.Math.IEEEremainder) + * + * @param x dividend + * @param y divisor + * @return x % y + */ private static double modulo(double x, double y) { - if (y == 0) { + if (y == 0.0) { throw Py.ZeroDivisionError("float modulo"); + } else { + double z = x % y; + if (z == 0.0) { + // Has to be same sign as y (even when zero). + return Math.copySign(z, y); + } else if ((z > 0.0) == (y > 0.0)) { + // z has same sign as y, as it must. + return z; + } else { + // Note abs(z) < abs(y) and opposite sign. 
+ return z + y; + } } - double z = Math.IEEEremainder(x, y); - if (z * y < 0) { - z += y; - } - return z; } @Override @@ -934,8 +949,9 @@ } /** - * Common code for PyFloat, {@link PyInteger} and {@link PyLong} to prepare a {@link FloatFormatter} from a parsed specification. - * The object returned has format method {@link FloatFormatter#format(double)}. + * Common code for PyFloat, {@link PyInteger} and {@link PyLong} to prepare a + * {@link FloatFormatter} from a parsed specification. The object returned has format method + * {@link FloatFormatter#format(double)}. * * @param spec a parsed PEP-3101 format specification. * @return a formatter ready to use, or null if the type is not a floating point format type. -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:07 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:07 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Prevent_String=2Eatof=28=29?= =?utf-8?q?_from_accepting_hex_notation=2E?= Message-ID: <20141231014103.22188.43496@psf.io> https://hg.python.org/jython/rev/f125a1eeacbc changeset: 7478:f125a1eeacbc user: Jeff Allen date: Wed Dec 17 08:27:24 2014 +0000 summary: Prevent String.atof() from accepting hex notation. Fixes a test failure in test_float: float("0x3.p-1") to raise ValueError. 
files: Lib/test/test_float.py | 5 +- src/org/python/core/PyString.java | 106 ++++++++++------- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py --- a/Lib/test/test_float.py +++ b/Lib/test/test_float.py @@ -35,9 +35,8 @@ self.assertEqual(float(" 3.14 "), 3.14) self.assertRaises(ValueError, float, " 0x3.1 ") - #FIXME: not raising ValueError on Jython: - #self.assertRaises(ValueError, float, " -0x3.p-1 ") - #self.assertRaises(ValueError, float, " +0x3.p-1 ") + self.assertRaises(ValueError, float, " -0x3.p-1 ") + self.assertRaises(ValueError, float, " +0x3.p-1 ") self.assertRaises(ValueError, float, "++3.14") self.assertRaises(ValueError, float, "+-3.14") diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -4,6 +4,8 @@ import java.lang.ref.Reference; import java.lang.ref.SoftReference; import java.math.BigInteger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.python.core.buffer.BaseBuffer; import org.python.core.buffer.SimpleStringBuffer; @@ -2710,55 +2712,75 @@ } } + /** + * Convert this PyString to a floating-point value according to Python rules. 
+ * + * @return the value + */ public double atof() { - StringBuilder s = null; - int n = getString().length(); - for (int i = 0; i < n; i++) { - char ch = getString().charAt(i); - if (ch == '\u0000') { - throw Py.ValueError("null byte in argument for float()"); + String bogus = null; + double x = 0.0; + Matcher m = atofPattern.matcher(getString()); + + if (m.matches()) { + // Might be a valid float + try { + if (m.group(3) == null) { + // No numeric part was found: it's something like "-Inf" or "hOrsE" + x = atofSpecials(m.group(1)); + } else { + // A numeric part was present, try to convert the whole + x = Double.parseDouble(m.group(1)); + } + } catch (NumberFormatException e) { + bogus = m.group(1); } - if (Character.isDigit(ch)) { - if (s == null) { - s = new StringBuilder(getString()); - } - int val = Character.digit(ch, 10); - s.setCharAt(i, Character.forDigit(val, 10)); - } + } else { + // This doesn't match the pattern for a float value + bogus = getString().trim(); } - String sval = getString(); - if (s != null) { - sval = s.toString(); + + // At this point, bogus will have been set to the trimmed string if there was a problem. + if (bogus == null) { + return x; + } else { + String fmt = "could not convert string to float: %s"; + throw Py.ValueError(String.format(fmt, bogus)); } - try { - // Double.valueOf allows format specifier ("d" or "f") at the end - String lowSval = sval.toLowerCase(); - if (lowSval.equals("nan")) { + + } + + /** + * Regular expression that includes all valid a Python float() arguments, in which group 1 + * captures the whole, stripped of white space, and group 3 will be present only if the form is + * numeric. Invalid non numerics are accepted ("+hOrsE" as "-inf"). + */ + private static Pattern atofPattern = Pattern + .compile("\\s*([+-]?(((\\d+(\\.\\d*)?|\\.\\d+)([eE][+-]?\\d+)?)|\\p{Alpha}+))\\s*"); + + /** + * Conversion for non-numeric floats, accepting signed or unsigned "inf" and "nan", in any case. 
+ * + * @param s to convert + * @return non-numeric result (if valid) + * @throws NumberFormatException if not a valid non-numeric indicator + */ + private static double atofSpecials(String s) throws NumberFormatException { + switch (s.toLowerCase()) { + case "nan": + case "+nan": + case "-nan": return Double.NaN; - } else if (lowSval.equals("+nan")) { - return Double.NaN; - } else if (lowSval.equals("-nan")) { - return Double.NaN; - } else if (lowSval.equals("inf")) { + case "inf": + case "+inf": + case "infinity": + case "+infinity": return Double.POSITIVE_INFINITY; - } else if (lowSval.equals("+inf")) { - return Double.POSITIVE_INFINITY; - } else if (lowSval.equals("-inf")) { + case "-inf": + case "-infinity": return Double.NEGATIVE_INFINITY; - } else if (lowSval.equals("infinity")) { - return Double.POSITIVE_INFINITY; - } else if (lowSval.equals("+infinity")) { - return Double.POSITIVE_INFINITY; - } else if (lowSval.equals("-infinity")) { - return Double.NEGATIVE_INFINITY; - } - - if (lowSval.endsWith("d") || lowSval.endsWith("f")) { - throw new NumberFormatException("format specifiers not allowed"); - } - return Double.valueOf(sval).doubleValue(); - } catch (NumberFormatException exc) { - throw Py.ValueError("invalid literal for __float__: " + getString()); + default: + throw new NumberFormatException(); } } -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:06 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:06 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Formatting_changes_only=2C_?= =?utf-8?q?ahead_of_fixes_to_compiler/Module=2Ejava?= Message-ID: <20141231014104.71905.86826@psf.io> https://hg.python.org/jython/rev/1010be768075 changeset: 7483:1010be768075 user: Jeff Allen date: Sat Dec 20 21:39:38 2014 +0000 summary: Formatting changes only, ahead of fixes to compiler/Module.java files: src/org/python/compiler/Module.java | 198 ++++++++------- 1 files changed, 107 
insertions(+), 91 deletions(-) diff --git a/src/org/python/compiler/Module.java b/src/org/python/compiler/Module.java --- a/src/org/python/compiler/Module.java +++ b/src/org/python/compiler/Module.java @@ -1,6 +1,10 @@ // Copyright (c) Corporation for National Research Initiatives package org.python.compiler; +import static org.python.util.CodegenUtils.ci; +import static org.python.util.CodegenUtils.p; +import static org.python.util.CodegenUtils.sig; + import java.io.IOException; import java.io.OutputStream; import java.util.ArrayList; @@ -10,12 +14,17 @@ import org.objectweb.asm.Label; import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; +import org.python.antlr.ParseException; +import org.python.antlr.PythonTree; import org.python.antlr.ast.Num; +import org.python.antlr.ast.Str; +import org.python.antlr.ast.Suite; +import org.python.antlr.base.mod; import org.python.core.CodeBootstrap; import org.python.core.CodeFlag; import org.python.core.CodeLoader; import org.python.core.CompilerFlags; -import org.python.core.ThreadState; import org.python.core.Py; import org.python.core.PyCode; import org.python.core.PyComplex; @@ -30,13 +39,7 @@ import org.python.core.PyRunnableBootstrap; import org.python.core.PyString; import org.python.core.PyUnicode; -import org.objectweb.asm.Type; -import org.python.antlr.ParseException; -import org.python.antlr.PythonTree; -import org.python.antlr.ast.Str; -import org.python.antlr.ast.Suite; -import org.python.antlr.base.mod; -import static org.python.util.CodegenUtils.*; +import org.python.core.ThreadState; class PyIntegerConstant extends Constant implements ClassConstants, Opcodes { @@ -46,13 +49,14 @@ this.value = value; } + @Override void get(Code c) throws IOException { c.iconst(value); // it would be nice if we knew we didn't have to box next c.invokestatic(p(Py.class), "newInteger", sig(PyInteger.class, Integer.TYPE)); } - void put(Code c) throws IOException { - } + @Override + void put(Code c) throws IOException 
{} @Override public int hashCode() { @@ -62,14 +66,16 @@ @Override public boolean equals(Object o) { if (o instanceof PyIntegerConstant) { - return ((PyIntegerConstant) o).value == value; + return ((PyIntegerConstant)o).value == value; } else { return false; } } } + class PyFloatConstant extends Constant implements ClassConstants, Opcodes { + private static final double ZERO = 0.0; final double value; @@ -78,17 +84,18 @@ this.value = value; } + @Override void get(Code c) throws IOException { c.ldc(new Double(value)); c.invokestatic(p(Py.class), "newFloat", sig(PyFloat.class, Double.TYPE)); } - void put(Code c) throws IOException { - } + @Override + void put(Code c) throws IOException {} @Override public int hashCode() { - return (int) value; + return (int)value; } @Override @@ -105,6 +112,7 @@ } } + class PyComplexConstant extends Constant implements ClassConstants, Opcodes { final double value; @@ -113,29 +121,31 @@ this.value = value; } + @Override void get(Code c) throws IOException { c.ldc(new Double(value)); c.invokestatic(p(Py.class), "newImaginary", sig(PyComplex.class, Double.TYPE)); } - void put(Code c) throws IOException { - } + @Override + void put(Code c) throws IOException {} @Override public int hashCode() { - return (int) value; + return (int)value; } @Override public boolean equals(Object o) { if (o instanceof PyComplexConstant) { - return ((PyComplexConstant) o).value == value; + return ((PyComplexConstant)o).value == value; } else { return false; } } } + class PyStringConstant extends Constant implements ClassConstants, Opcodes { final String value; @@ -144,13 +154,14 @@ this.value = value; } + @Override void get(Code c) throws IOException { c.ldc(value); c.invokestatic(p(PyString.class), "fromInterned", sig(PyString.class, String.class)); } - void put(Code c) throws IOException { - } + @Override + void put(Code c) throws IOException {} @Override public int hashCode() { @@ -160,13 +171,14 @@ @Override public boolean equals(Object o) { if (o 
instanceof PyStringConstant) { - return ((PyStringConstant) o).value.equals(value); + return ((PyStringConstant)o).value.equals(value); } else { return false; } } } + class PyUnicodeConstant extends Constant implements ClassConstants, Opcodes { final String value; @@ -175,13 +187,14 @@ this.value = value; } + @Override void get(Code c) throws IOException { c.ldc(value); c.invokestatic(p(PyUnicode.class), "fromInterned", sig(PyUnicode.class, String.class)); } - void put(Code c) throws IOException { - } + @Override + void put(Code c) throws IOException {} @Override public int hashCode() { @@ -191,13 +204,14 @@ @Override public boolean equals(Object o) { if (o instanceof PyUnicodeConstant) { - return ((PyUnicodeConstant) o).value.equals(value); + return ((PyUnicodeConstant)o).value.equals(value); } else { return false; } } } + class PyLongConstant extends Constant implements ClassConstants, Opcodes { final String value; @@ -206,14 +220,14 @@ this.value = value; } + @Override void get(Code c) throws IOException { c.ldc(value); c.invokestatic(p(Py.class), "newLong", sig(PyLong.class, String.class)); - } - void put(Code c) throws IOException { - } + @Override + void put(Code c) throws IOException {} @Override public int hashCode() { @@ -223,13 +237,14 @@ @Override public boolean equals(Object o) { if (o instanceof PyLongConstant) { - return ((PyLongConstant) o).value.equals(value); + return ((PyLongConstant)o).value.equals(value); } else { return false; } } } + class PyCodeConstant extends Constant implements ClassConstants, Opcodes { final String co_name; @@ -247,14 +262,13 @@ PyCodeConstant(mod tree, String name, boolean fast_locals, String className, boolean classBody, boolean printResults, int firstlineno, ScopeInfo scope, CompilerFlags cflags, - Module module) - throws Exception { + Module module) throws Exception { this.co_name = name; this.co_firstlineno = firstlineno; this.module = module; - //Needed so that moreflags can be final. 
+ // Needed so that moreflags can be final. int _moreflags = 0; if (scope.ac != null) { @@ -262,11 +276,11 @@ keywordlist = scope.ac.keywordlist; argcount = scope.ac.names.size(); - //Do something to add init_code to tree - //XXX: not sure we should be modifying scope.ac in a PyCodeConstant - //constructor. + // Do something to add init_code to tree + // XXX: not sure we should be modifying scope.ac in a PyCodeConstant + // constructor. if (scope.ac.init_code.size() > 0) { - scope.ac.appendInitCode((Suite) tree); + scope.ac.appendInitCode((Suite)tree); } } else { arglist = false; @@ -276,13 +290,13 @@ id = module.codes.size(); - //Better names in the future? + // Better names in the future? if (isJavaIdentifier(name)) { fname = name + "$" + id; } else { fname = "f$" + id; } - //XXX: is fname needed at all, or should we just use "name"? + // XXX: is fname needed at all, or should we just use "name"? this.name = fname; // !classdef only @@ -313,7 +327,7 @@ moreflags = _moreflags; } - //XXX: this can probably go away now that we can probably just copy the list. + // XXX: this can probably go away now that we can probably just copy the list. 
private List toNameAr(List names, boolean nullok) { int sz = names.size(); if (sz == 0 && nullok) { @@ -341,15 +355,17 @@ return true; } + @Override void get(Code c) throws IOException { c.getstatic(module.classfile.name, name, ci(PyCode.class)); } + @Override void put(Code c) throws IOException { module.classfile.addField(name, ci(PyCode.class), access); c.iconst(argcount); - //Make all names + // Make all names int nameArray; if (names != null) { nameArray = CodeCompiler.makeStrings(c, names); @@ -388,14 +404,17 @@ c.iconst(moreflags); - c.invokestatic(p(Py.class), "newCode", sig(PyCode.class, Integer.TYPE, - String[].class, String.class, String.class, Integer.TYPE, Boolean.TYPE, - Boolean.TYPE, PyFunctionTable.class, Integer.TYPE, String[].class, - String[].class, Integer.TYPE, Integer.TYPE)); + c.invokestatic( + p(Py.class), + "newCode", + sig(PyCode.class, Integer.TYPE, String[].class, String.class, String.class, + Integer.TYPE, Boolean.TYPE, Boolean.TYPE, PyFunctionTable.class, + Integer.TYPE, String[].class, String[].class, Integer.TYPE, Integer.TYPE)); c.putstatic(module.classfile.name, name, ci(PyCode.class)); } } + public class Module implements Opcodes, ClassConstants, CompilationContext { ClassFile classfile; @@ -421,8 +440,8 @@ public Module(String name, String filename, boolean linenumbers, long mtime) { this.linenumbers = linenumbers; this.mtime = mtime; - classfile = new ClassFile(name, p(PyFunctionTable.class), - ACC_SYNCHRONIZED | ACC_PUBLIC, mtime); + classfile = + new ClassFile(name, p(PyFunctionTable.class), ACC_SYNCHRONIZED | ACC_PUBLIC, mtime); constants = new Hashtable(); sfilename = filename; if (filename != null) { @@ -446,7 +465,7 @@ } ret = c; c.module = this; - //More sophisticated name mappings might be nice + // More sophisticated name mappings might be nice c.name = "_" + constants.size(); constants.put(ret, ret); return ret; @@ -477,27 +496,23 @@ } PyCodeConstant codeConstant(mod tree, String name, boolean fast_locals, String 
className, - boolean classBody, boolean printResults, int firstlineno, - ScopeInfo scope, CompilerFlags cflags) throws Exception { + boolean classBody, boolean printResults, int firstlineno, ScopeInfo scope, + CompilerFlags cflags) throws Exception { return codeConstant(tree, name, fast_locals, className, null, classBody, printResults, - firstlineno, scope, cflags); + firstlineno, scope, cflags); } PyCodeConstant codeConstant(mod tree, String name, boolean fast_locals, String className, - Str classDoc, boolean classBody, boolean printResults, - int firstlineno, ScopeInfo scope, CompilerFlags cflags) - throws Exception { - PyCodeConstant code = new PyCodeConstant(tree, name, fast_locals, - className, classBody, printResults, firstlineno, scope, cflags, - this); + Str classDoc, boolean classBody, boolean printResults, int firstlineno, + ScopeInfo scope, CompilerFlags cflags) throws Exception { + PyCodeConstant code = new PyCodeConstant(tree, name, fast_locals, className, classBody, // + printResults, firstlineno, scope, cflags, this); codes.add(code); CodeCompiler compiler = new CodeCompiler(this, printResults); - Code c = classfile.addMethod( - code.fname, - sig(PyObject.class, PyFrame.class, ThreadState.class), - ACC_PUBLIC); + Code c = classfile.addMethod(code.fname, // + sig(PyObject.class, PyFrame.class, ThreadState.class), ACC_PUBLIC); compiler.parse(tree, c, fast_locals, className, classDoc, classBody, scope, cflags); return code; @@ -518,8 +533,8 @@ } public void addMain() throws IOException { - Code c = classfile.addMethod("main", sig(Void.TYPE, String[].class), - ACC_PUBLIC | ACC_STATIC); + Code c = classfile.addMethod("main", // + sig(Void.TYPE, String[].class), ACC_PUBLIC | ACC_STATIC); c.new_(classfile.name); c.dup(); c.ldc(classfile.name); @@ -533,8 +548,8 @@ } public void addBootstrap() throws IOException { - Code c = classfile.addMethod(CodeLoader.GET_BOOTSTRAP_METHOD_NAME, sig(CodeBootstrap.class), - ACC_PUBLIC | ACC_STATIC); + Code c = 
classfile.addMethod(CodeLoader.GET_BOOTSTRAP_METHOD_NAME, // + sig(CodeBootstrap.class), ACC_PUBLIC | ACC_STATIC); c.ldc(Type.getType("L" + classfile.name + ";")); c.invokestatic(p(PyRunnableBootstrap.class), PyRunnableBootstrap.REFLECTION_METHOD_NAME, sig(CodeBootstrap.class, Class.class)); @@ -548,7 +563,7 @@ Enumeration e = constants.elements(); while (e.hasMoreElements()) { - Constant constant = (Constant) e.nextElement(); + Constant constant = (Constant)e.nextElement(); constant.put(c); } @@ -561,7 +576,7 @@ } public void addFunctions() throws IOException { - Code code = classfile.addMethod("call_function", + Code code = classfile.addMethod("call_function", // sig(PyObject.class, Integer.TYPE, PyFrame.class, ThreadState.class), ACC_PUBLIC); code.aload(0); // this @@ -574,18 +589,18 @@ labels[i] = new Label(); } - //Get index for function to call + // Get index for function to call code.iload(1); code.tableswitch(0, labels.length - 1, def, labels); for (i = 0; i < labels.length; i++) { code.label(labels[i]); - code.invokevirtual(classfile.name, (codes.get(i)).fname, sig(PyObject.class, - PyFrame.class, ThreadState.class)); + code.invokevirtual(classfile.name, (codes.get(i)).fname, + sig(PyObject.class, PyFrame.class, ThreadState.class)); code.areturn(); } code.label(def); - //Should probably throw internal exception here + // Should probably throw internal exception here code.aconst_null(); code.areturn(); } @@ -606,20 +621,23 @@ } // Implementation of CompilationContext + @Override public Future getFutures() { return futures; } + @Override public String getFilename() { return sfilename; } + @Override public ScopeInfo getScopeInfo(PythonTree node) { return scopes.get(node); } - public void error(String msg, boolean err, PythonTree node) - throws Exception { + @Override + public void error(String msg, boolean err, PythonTree node) throws Exception { if (!err) { try { Py.warning(Py.SyntaxWarning, msg, (sfilename != null) ? 
sfilename : "?", @@ -635,8 +653,7 @@ } public static void compile(mod node, OutputStream ostream, String name, String filename, - boolean linenumbers, boolean printResults, CompilerFlags cflags) - throws Exception { + boolean linenumbers, boolean printResults, CompilerFlags cflags) throws Exception { compile(node, ostream, name, filename, linenumbers, printResults, cflags, org.python.core.imp.NO_MTIME); } @@ -651,30 +668,28 @@ module.futures.preprocessFutures(node, cflags); new ScopesCompiler(module, module.scopes).parse(node); - //Add __doc__ if it exists + // Add __doc__ if it exists - Constant main = module.codeConstant(node, "", false, null, false, - printResults, 0, - module.getScopeInfo(node), - cflags); + Constant main = module.codeConstant(node, "", false, null, false, // + printResults, 0, module.getScopeInfo(node), cflags); module.mainCode = main; module.write(ostream); } public void emitNum(Num node, Code code) throws Exception { if (node.getInternalN() instanceof PyInteger) { - integerConstant(((PyInteger) node.getInternalN()).getValue()).get(code); + integerConstant(((PyInteger)node.getInternalN()).getValue()).get(code); } else if (node.getInternalN() instanceof PyLong) { - longConstant(((PyObject) node.getInternalN()).__str__().toString()).get(code); + longConstant(((PyObject)node.getInternalN()).__str__().toString()).get(code); } else if (node.getInternalN() instanceof PyFloat) { - floatConstant(((PyFloat) node.getInternalN()).getValue()).get(code); + floatConstant(((PyFloat)node.getInternalN()).getValue()).get(code); } else if (node.getInternalN() instanceof PyComplex) { - complexConstant(((PyComplex) node.getInternalN()).imag).get(code); + complexConstant(((PyComplex)node.getInternalN()).imag).get(code); } } public void emitStr(Str node, Code code) throws Exception { - PyString s = (PyString) node.getInternalS(); + PyString s = (PyString)node.getInternalS(); if (s instanceof PyUnicode) { unicodeConstant(s.asString()).get(code); } else { @@ -682,7 
+697,8 @@ } } - public boolean emitPrimitiveArraySetters(java.util.List nodes, Code code) throws Exception { + public boolean emitPrimitiveArraySetters(java.util.List nodes, Code code) + throws Exception { final int n = nodes.size(); if (n < USE_SETTERS_LIMIT) { return false; // Too small to matter, so bail @@ -704,25 +720,25 @@ code.iconst(n); code.anewarray(p(PyObject.class)); for (int i = 0; i < num_setters; i++) { - Code setter = this.classfile.addMethod( - "set$$" + setter_count, sig(Void.TYPE, PyObject[].class), ACC_STATIC | ACC_PRIVATE); + Code setter = this.classfile.addMethod("set$$" + setter_count, // + sig(Void.TYPE, PyObject[].class), ACC_STATIC | ACC_PRIVATE); - for (int j = 0; (j < MAX_SETTINGS_PER_SETTER) && ((i * MAX_SETTINGS_PER_SETTER + j) < n); j++) { + for (int j = 0; (j < MAX_SETTINGS_PER_SETTER) + && ((i * MAX_SETTINGS_PER_SETTER + j) < n); j++) { setter.aload(0); setter.iconst(i * MAX_SETTINGS_PER_SETTER + j); PythonTree node = nodes.get(i * MAX_SETTINGS_PER_SETTER + j); if (node instanceof Num) { emitNum((Num)node, setter); - } - else if (node instanceof Str) { + } else if (node instanceof Str) { emitStr((Str)node, setter); } setter.aastore(); } setter.return_(); code.dup(); - code.invokestatic( - this.classfile.name, "set$$" + setter_count, sig(Void.TYPE, PyObject[].class)); + code.invokestatic(this.classfile.name, "set$$" + setter_count, + sig(Void.TYPE, PyObject[].class)); setter_count++; } return true; -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:06 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:06 +0000 Subject: [Jython-checkins] =?utf-8?b?anl0aG9uOiBGaXggc3RyLl9fY29tcGxleF9f?= =?utf-8?q?=28=29_to_accept_all_valid_inputs=2E?= Message-ID: <20141231014104.120047.45144@psf.io> https://hg.python.org/jython/rev/5409ad3e93f4 changeset: 7482:5409ad3e93f4 user: Jeff Allen date: Sat Dec 20 09:51:58 2014 +0000 summary: Fix str.__complex__() to accept all 
valid inputs. Fixes a number of test failures in test_complex. files: Lib/test/test_complex.py | 220 ------------- src/org/python/core/PyString.java | 302 ++++++++--------- 2 files changed, 149 insertions(+), 373 deletions(-) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -220,7 +220,6 @@ def test_conjugate(self): self.assertClose(complex(5.3, 9.8).conjugate(), 5.3-9.8j) - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") def test_constructor(self): class OS: def __init__(self, value): self.value = value @@ -264,24 +263,16 @@ self.assertAlmostEqual(complex(), 0) self.assertAlmostEqual(complex("-1"), -1) self.assertAlmostEqual(complex("+1"), +1) - #FIXME: these are not working in Jython. self.assertAlmostEqual(complex("(1+2j)"), 1+2j) self.assertAlmostEqual(complex("(1.3+2.2j)"), 1.3+2.2j) - # ] self.assertAlmostEqual(complex("3.14+1J"), 3.14+1j) - #FIXME: these are not working in Jython. self.assertAlmostEqual(complex(" ( +3.14-6J )"), 3.14-6j) self.assertAlmostEqual(complex(" ( +3.14-J )"), 3.14-1j) self.assertAlmostEqual(complex(" ( +3.14+j )"), 3.14+1j) - # ] self.assertAlmostEqual(complex("J"), 1j) - #FIXME: this is not working in Jython. self.assertAlmostEqual(complex("( j )"), 1j) - # ] self.assertAlmostEqual(complex("+J"), 1j) - #FIXME: this is not working in Jython. self.assertAlmostEqual(complex("( -j)"), -1j) - # ] self.assertAlmostEqual(complex('1e-500'), 0.0 + 0.0j) self.assertAlmostEqual(complex('-1e-500j'), 0.0 - 0.0j) self.assertAlmostEqual(complex('-1e-500+1e-500j'), -0.0 + 0.0j) @@ -301,9 +292,7 @@ return atan2(x, -1.) self.assertEqual(split_zeros(complex(1., 0.).imag), split_zeros(0.)) - #FIXME: this is not working in Jython. 
self.assertEqual(split_zeros(complex(1., -0.).imag), split_zeros(-0.)) - # ] self.assertEqual(split_zeros(complex(0., 1.).real), split_zeros(0.)) self.assertEqual(split_zeros(complex(-0., 1.).real), split_zeros(-0.)) @@ -340,199 +329,19 @@ self.assertRaises(ValueError, complex, "(1+2j)123") if test_support.have_unicode: self.assertRaises(ValueError, complex, unicode("x")) - #FIXME: these are raising wrong errors in Jython. self.assertRaises(ValueError, complex, "1j+2") self.assertRaises(ValueError, complex, "1e1ej") self.assertRaises(ValueError, complex, "1e++1ej") self.assertRaises(ValueError, complex, ")1+2j(") - # ] # the following three are accepted by Python 2.6 - #FIXME: these are raising wrong errors in Jython. self.assertRaises(ValueError, complex, "1..1j") self.assertRaises(ValueError, complex, "1.11.1j") self.assertRaises(ValueError, complex, "1e1.1j") - # ] - #FIXME: not working in Jython. if test_support.have_unicode: # check that complex accepts long unicode strings self.assertEqual(type(complex(unicode("1"*500))), complex) - # ] - - class EvilExc(Exception): - pass - - class evilcomplex: - def __complex__(self): - raise EvilExc - - self.assertRaises(EvilExc, complex, evilcomplex()) - - class float2: - def __init__(self, value): - self.value = value - def __float__(self): - return self.value - - self.assertAlmostEqual(complex(float2(42.)), 42) - self.assertAlmostEqual(complex(real=float2(17.), imag=float2(23.)), 17+23j) - self.assertRaises(TypeError, complex, float2(None)) - - class complex0(complex): - """Test usage of __complex__() when inheriting from 'complex'""" - def __complex__(self): - return 42j - - class complex1(complex): - """Test usage of __complex__() with a __new__() method""" - def __new__(self, value=0j): - return complex.__new__(self, 2*value) - def __complex__(self): - return self - - class complex2(complex): - """Make sure that __complex__() calls fail if anything other than a - complex is returned""" - def __complex__(self): - 
return None - - self.assertAlmostEqual(complex(complex0(1j)), 42j) - self.assertAlmostEqual(complex(complex1(1j)), 2j) - self.assertRaises(TypeError, complex, complex2(1j)) - - def test_constructor_jy(self): - # These are the parts of test_constructor that work in Jython. - # Delete this test when test_constructor skip is removed. - class OS: - def __init__(self, value): self.value = value - def __complex__(self): return self.value - class NS(object): - def __init__(self, value): self.value = value - def __complex__(self): return self.value - self.assertEqual(complex(OS(1+10j)), 1+10j) - self.assertEqual(complex(NS(1+10j)), 1+10j) - self.assertRaises(TypeError, complex, OS(None)) - self.assertRaises(TypeError, complex, NS(None)) - - self.assertAlmostEqual(complex("1+10j"), 1+10j) - self.assertAlmostEqual(complex(10), 10+0j) - self.assertAlmostEqual(complex(10.0), 10+0j) - self.assertAlmostEqual(complex(10L), 10+0j) - self.assertAlmostEqual(complex(10+0j), 10+0j) - self.assertAlmostEqual(complex(1,10), 1+10j) - self.assertAlmostEqual(complex(1,10L), 1+10j) - self.assertAlmostEqual(complex(1,10.0), 1+10j) - self.assertAlmostEqual(complex(1L,10), 1+10j) - self.assertAlmostEqual(complex(1L,10L), 1+10j) - self.assertAlmostEqual(complex(1L,10.0), 1+10j) - self.assertAlmostEqual(complex(1.0,10), 1+10j) - self.assertAlmostEqual(complex(1.0,10L), 1+10j) - self.assertAlmostEqual(complex(1.0,10.0), 1+10j) - self.assertAlmostEqual(complex(3.14+0j), 3.14+0j) - self.assertAlmostEqual(complex(3.14), 3.14+0j) - self.assertAlmostEqual(complex(314), 314.0+0j) - self.assertAlmostEqual(complex(314L), 314.0+0j) - self.assertAlmostEqual(complex(3.14+0j, 0j), 3.14+0j) - self.assertAlmostEqual(complex(3.14, 0.0), 3.14+0j) - self.assertAlmostEqual(complex(314, 0), 314.0+0j) - self.assertAlmostEqual(complex(314L, 0L), 314.0+0j) - self.assertAlmostEqual(complex(0j, 3.14j), -3.14+0j) - self.assertAlmostEqual(complex(0.0, 3.14j), -3.14+0j) - self.assertAlmostEqual(complex(0j, 3.14), 3.14j) - 
self.assertAlmostEqual(complex(0.0, 3.14), 3.14j) - self.assertAlmostEqual(complex("1"), 1+0j) - self.assertAlmostEqual(complex("1j"), 1j) - self.assertAlmostEqual(complex(), 0) - self.assertAlmostEqual(complex("-1"), -1) - self.assertAlmostEqual(complex("+1"), +1) - #FIXME: these are not working in Jython. - #self.assertAlmostEqual(complex("(1+2j)"), 1+2j) - #self.assertAlmostEqual(complex("(1.3+2.2j)"), 1.3+2.2j) - self.assertAlmostEqual(complex("3.14+1J"), 3.14+1j) - #FIXME: these are not working in Jython. - #self.assertAlmostEqual(complex(" ( +3.14-6J )"), 3.14-6j) - #self.assertAlmostEqual(complex(" ( +3.14-J )"), 3.14-1j) - #self.assertAlmostEqual(complex(" ( +3.14+j )"), 3.14+1j) - self.assertAlmostEqual(complex("J"), 1j) - #FIXME: this is not working in Jython. - #self.assertAlmostEqual(complex("( j )"), 1j) - self.assertAlmostEqual(complex("+J"), 1j) - #FIXME: this is not working in Jython. - #self.assertAlmostEqual(complex("( -j)"), -1j) - self.assertAlmostEqual(complex('1e-500'), 0.0 + 0.0j) - self.assertAlmostEqual(complex('-1e-500j'), 0.0 - 0.0j) - self.assertAlmostEqual(complex('-1e-500+1e-500j'), -0.0 + 0.0j) - - class complex2(complex): pass - self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j) - self.assertAlmostEqual(complex(real=17, imag=23), 17+23j) - self.assertAlmostEqual(complex(real=17+23j), 17+23j) - self.assertAlmostEqual(complex(real=17+23j, imag=23), 17+46j) - self.assertAlmostEqual(complex(real=1+2j, imag=3+4j), -3+5j) - - # check that the sign of a zero in the real or imaginary part - # is preserved when constructing from two floats. (These checks - # are harmless on systems without support for signed zeros.) - def split_zeros(x): - """Function that produces different results for 0. and -0.""" - return atan2(x, -1.) - - self.assertEqual(split_zeros(complex(1., 0.).imag), split_zeros(0.)) - #FIXME: this is not working in Jython. 
- #self.assertEqual(split_zeros(complex(1., -0.).imag), split_zeros(-0.)) - self.assertEqual(split_zeros(complex(0., 1.).real), split_zeros(0.)) - self.assertEqual(split_zeros(complex(-0., 1.).real), split_zeros(-0.)) - - c = 3.14 + 1j - self.assertTrue(complex(c) is c) - del c - - self.assertRaises(TypeError, complex, "1", "1") - self.assertRaises(TypeError, complex, 1, "1") - - if test_support.have_unicode: - self.assertEqual(complex(unicode(" 3.14+J ")), 3.14+1j) - - # SF bug 543840: complex(string) accepts strings with \0 - # Fixed in 2.3. - self.assertRaises(ValueError, complex, '1+1j\0j') - - self.assertRaises(TypeError, int, 5+3j) - self.assertRaises(TypeError, long, 5+3j) - self.assertRaises(TypeError, float, 5+3j) - self.assertRaises(ValueError, complex, "") - self.assertRaises(TypeError, complex, None) - self.assertRaises(ValueError, complex, "\0") - self.assertRaises(ValueError, complex, "3\09") - self.assertRaises(TypeError, complex, "1", "2") - self.assertRaises(TypeError, complex, "1", 42) - self.assertRaises(TypeError, complex, 1, "2") - self.assertRaises(ValueError, complex, "1+") - self.assertRaises(ValueError, complex, "1+1j+1j") - self.assertRaises(ValueError, complex, "--") - self.assertRaises(ValueError, complex, "(1+2j") - self.assertRaises(ValueError, complex, "1+2j)") - self.assertRaises(ValueError, complex, "1+(2j)") - self.assertRaises(ValueError, complex, "(1+2j)123") - if test_support.have_unicode: - self.assertRaises(ValueError, complex, unicode("x")) - #FIXME: these are raising wrong errors in Jython. - #self.assertRaises(ValueError, complex, "1j+2") - #self.assertRaises(ValueError, complex, "1e1ej") - #self.assertRaises(ValueError, complex, "1e++1ej") - #self.assertRaises(ValueError, complex, ")1+2j(") - - # the following three are accepted by Python 2.6 - #FIXME: these are raising wrong errors in Jython. 
- #self.assertRaises(ValueError, complex, "1..1j") - #self.assertRaises(ValueError, complex, "1.11.1j") - #self.assertRaises(ValueError, complex, "1e1.1j") - - #FIXME: not working in Jython. - #if test_support.have_unicode: - # # check that complex accepts long unicode strings - # self.assertEqual(type(complex(unicode("1"*500))), complex) class EvilExc(Exception): pass @@ -641,7 +450,6 @@ for num in nums: self.assertAlmostEqual((num.real**2 + num.imag**2) ** 0.5, abs(num)) - @unittest.skipIf(test_support.is_jython, "FIXME: str.__complex__ not working in Jython") def test_repr(self): self.assertEqual(repr(1+6j), '(1+6j)') self.assertEqual(repr(1-6j), '(1-6j)') @@ -665,32 +473,6 @@ self.assertEqual(repr(complex(0, -INF)), "-infj") self.assertEqual(repr(complex(0, NAN)), "nanj") - def test_repr_jy(self): - # These are just the cases that Jython can do from test_repr - # Delete this test when test_repr passes - self.assertEqual(repr(1+6j), '(1+6j)') - self.assertEqual(repr(1-6j), '(1-6j)') - - self.assertNotEqual(repr(-(1+0j)), '(-1+-0j)') - - # Fails to round-trip: -# self.assertEqual(1-6j,complex(repr(1-6j))) -# self.assertEqual(1+6j,complex(repr(1+6j))) -# self.assertEqual(-6j,complex(repr(-6j))) -# self.assertEqual(6j,complex(repr(6j))) - - self.assertEqual(repr(complex(1., INF)), "(1+infj)") - self.assertEqual(repr(complex(1., -INF)), "(1-infj)") - self.assertEqual(repr(complex(INF, 1)), "(inf+1j)") - self.assertEqual(repr(complex(-INF, INF)), "(-inf+infj)") - self.assertEqual(repr(complex(NAN, 1)), "(nan+1j)") - self.assertEqual(repr(complex(1, NAN)), "(1+nanj)") - self.assertEqual(repr(complex(NAN, NAN)), "(nan+nanj)") - - self.assertEqual(repr(complex(0, INF)), "infj") - self.assertEqual(repr(complex(0, -INF)), "-infj") - self.assertEqual(repr(complex(0, NAN)), "nanj") - def test_neg(self): self.assertEqual(-(1+6j), -1-6j) @@ -729,7 +511,6 @@ @unittest.skipUnless(float.__getformat__("double").startswith("IEEE"), "test requires IEEE 754 doubles") - 
@unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") def test_overflow(self): self.assertEqual(complex("1e500"), complex(INF, 0.0)) self.assertEqual(complex("-1e500j"), complex(0.0, -INF)) @@ -737,7 +518,6 @@ @unittest.skipUnless(float.__getformat__("double").startswith("IEEE"), "test requires IEEE 754 doubles") - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") def test_repr_roundtrip(self): vals = [0.0, 1e-500, 1e-315, 1e-200, 0.0123, 3.1415, 1e50, INF, NAN] vals += [-v for v in vals] diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java --- a/src/org/python/core/PyString.java +++ b/src/org/python/core/PyString.java @@ -1021,141 +1021,9 @@ throw Py.TypeError("bad operand type for unary ~"); } - @SuppressWarnings("fallthrough") @Override public PyComplex __complex__() { - boolean got_re = false; - boolean got_im = false; - boolean done = false; - boolean sw_error = false; - - int s = 0; - int n = getString().length(); - while (s < n && Character.isSpaceChar(getString().charAt(s))) { - s++; - } - - if (s == n) { - throw Py.ValueError("empty string for complex()"); - } - - double z = -1.0; - double x = 0.0; - double y = 0.0; - - int sign = 1; - do { - char c = getString().charAt(s); - switch (c) { - case '-': - sign = -1; - /* Fallthrough */ - case '+': - if (done || s + 1 == n) { - sw_error = true; - break; - } - // a character is guaranteed, but it better be a digit - // or J or j - c = getString().charAt(++s); // eat the sign character - // and check the next - if (!Character.isDigit(c) && c != 'J' && c != 'j') { - sw_error = true; - } - break; - - case 'J': - case 'j': - if (got_im || done) { - sw_error = true; - break; - } - if (z < 0.0) { - y = sign; - } else { - y = sign * z; - } - got_im = true; - done = got_re; - sign = 1; - s++; // eat the J or j - break; - - case ' ': - while (s < n && Character.isSpaceChar(getString().charAt(s))) { - s++; - } - if (s != n) { - sw_error = true; - } - 
break; - - default: - boolean digit_or_dot = (c == '.' || Character.isDigit(c)); - if (!digit_or_dot) { - sw_error = true; - break; - } - int end = endDouble(getString(), s); - z = Double.valueOf(getString().substring(s, end)).doubleValue(); - if (z == Double.POSITIVE_INFINITY) { - throw Py.ValueError(String.format("float() out of range: %.150s", - getString())); - } - - s = end; - if (s < n) { - c = getString().charAt(s); - if (c == 'J' || c == 'j') { - break; - } - } - if (got_re) { - sw_error = true; - break; - } - - /* accept a real part */ - x = sign * z; - got_re = true; - done = got_im; - z = -1.0; - sign = 1; - break; - - } /* end of switch */ - - } while (s < n && !sw_error); - - if (sw_error) { - throw Py.ValueError("malformed string for complex() " + getString().substring(s)); - } - - return new PyComplex(x, y); - } - - private int endDouble(String string, int s) { - int n = string.length(); - while (s < n) { - char c = string.charAt(s++); - if (Character.isDigit(c)) { - continue; - } - if (c == '.') { - continue; - } - if (c == 'e' || c == 'E') { - if (s < n) { - c = string.charAt(s); - if (c == '+' || c == '-') { - s++; - } - continue; - } - } - return s - 1; - } - return s; + return atocx(); } // Add in methods from string module @@ -2718,45 +2586,80 @@ * @return the value */ public double atof() { - String bogus = null; double x = 0.0; - Matcher m = atofPattern.matcher(getString()); - - if (m.matches()) { - // Might be a valid float + Matcher m = getFloatPattern().matcher(getString()); + boolean valid = m.matches(); + + if (valid) { + // Might be a valid float: trimmed of white space in group 1. 
+ String number = m.group(1); try { - if (m.group(3) == null) { - // No numeric part was found: it's something like "-Inf" or "hOrsE" + char lastChar = number.charAt(number.length() - 1); + if (Character.isLetter(lastChar)) { + // It's something like "nan", "-Inf" or "+nifty" x = atofSpecials(m.group(1)); } else { // A numeric part was present, try to convert the whole x = Double.parseDouble(m.group(1)); } } catch (NumberFormatException e) { - bogus = m.group(1); + valid = false; } - } else { - // This doesn't match the pattern for a float value - bogus = getString().trim(); } - // At this point, bogus will have been set to the trimmed string if there was a problem. - if (bogus == null) { + // At this point, valid will have been cleared if there was a problem. + if (valid) { return x; } else { - String fmt = "could not convert string to float: %s"; - throw Py.ValueError(String.format(fmt, bogus)); + String fmt = "invalid literal for float: %s"; + throw Py.ValueError(String.format(fmt, getString().trim())); } } /** - * Regular expression that includes all valid a Python float() arguments, in which group 1 - * captures the whole, stripped of white space, and group 3 will be present only if the form is - * numeric. Invalid non numerics are accepted ("+hOrsE" as "-inf"). + * Regular expression for an unsigned Python float, accepting also any sequence of the letters + * that belong to "NaN" or "Infinity" in whatever case. This is used within the regular + * expression patterns that define a priori acceptable strings in the float and complex + * constructors. The expression contributes no capture groups. 
*/ - private static Pattern atofPattern = Pattern - .compile("\\s*([+-]?(((\\d+(\\.\\d*)?|\\.\\d+)([eE][+-]?\\d+)?)|\\p{Alpha}+))\\s*"); + private static final String UF_RE = + "(?:(?:(?:\\d+\\.?|\\.\\d)\\d*(?:[eE][+-]?\\d+)?)|[infatyINFATY]+)"; + + /** + * Return the (lazily) compiled regular expression that matches all valid a Python float() + * arguments, in which Group 1 captures the number, stripped of white space. Various invalid + * non-numerics are provisionally accepted (e.g. "+inanity" or "-faint"). + */ + private static synchronized Pattern getFloatPattern() { + if (floatPattern == null) { + floatPattern = Pattern.compile("\\s*([+-]?" + UF_RE + ")\\s*"); + } + return floatPattern; + } + + /** Access only through {@link #getFloatPattern()}. */ + private static Pattern floatPattern = null; + + /** + * Return the (lazily) compiled regular expression for a Python complex number. This is used + * within the regular expression patterns that define a priori acceptable strings in the complex + * constructors. The expression contributes five named capture groups a, b, x, y and j. x and y + * are the two floats encountered, and if j is present, one of them is the imaginary part. + * a and b are the optional parentheses. They must either both be present or both omitted. + */ + private static synchronized Pattern getComplexPattern() { + if (complexPattern == null) { + complexPattern = Pattern.compile("\\s*(?<a>\\(\\s*)?" // Parenthesis <a> + "(?<x>[+-]?" + UF_RE + "?)" // <x> + "(?<y>[+-]" + UF_RE + "?)?(?<j>[jJ])?" // + <y> <j> + "\\s*(?<b>\\)\\s*)?"); // Parenthesis <b> + } + return complexPattern; + } + + /** Access only through {@link #getComplexPattern()} */ + private static Pattern complexPattern = null; /** * Conversion for non-numeric floats, accepting signed or unsigned "inf" and "nan", in any case. @@ -2784,6 +2687,99 @@ } } + /** + * Convert this PyString to a complex value according to Python rules. 
+ * + * @return the value + */ + private PyComplex atocx() { + double x = 0.0, y = 0.0; + Matcher m = getComplexPattern().matcher(getString()); + boolean valid = m.matches(); + + if (valid) { + // Passes a priori, but we have some checks to make. Brackets: both or neither. + if ((m.group("a") != null) != (m.group("b") != null)) { + valid = false; + + } else { + try { + // Pick up the two numbers [+-]? <x> [+-] <y> j? + String xs = m.group("x"), ys = m.group("y"); + + if (m.group("j") != null) { + // There is a 'j', so there is an imaginary part. + if (ys != null) { + // There were two numbers, so the second is the imaginary part. + y = toComplexPart(ys); + // And the first is the real part + x = toComplexPart(xs); + } else if (xs != null) { + // There was only one number (and a 'j')so it is the imaginary part. + y = toComplexPart(xs); + // x = 0.0; + } else { + // There were no numbers, just the 'j'. (Impossible return?) + y = 1.0; + // x = 0.0; + } + + } else { + // There is no 'j' so can only be one number, the real part. + x = Double.parseDouble(xs); + if (ys != null) { + // Something like "123 +" or "123 + 456" but no 'j'. + throw new NumberFormatException(); + } + } + + } catch (NumberFormatException e) { + valid = false; + } + } + } + + // At this point, valid will have been cleared if there was a problem. + if (valid) { + return new PyComplex(x, y); + } else { + String fmt = "complex() arg is a malformed string: %s"; + throw Py.ValueError(String.format(fmt, getString().trim())); + } + + } + + /** + * Helper for interpreting each part (real and imaginary) of a complex number expressed as a + * string in {@link #atocx(String)}. It deals with numbers, inf, nan and their variants, and + * with the "implied one" in +j or 10-j. 
+ * + * @param s to interpret + * @return value of s + * @throws NumberFormatException if the number is invalid + */ + private static double toComplexPart(String s) throws NumberFormatException { + if (s.length() == 0) { + // Empty string (occurs only as 'j') + return 1.0; + } else { + char lastChar = s.charAt(s.length() - 1); + if (Character.isLetter(lastChar)) { + // Possibly a sign, then letters that ought to be "nan" or "inf[inity]" + return atofSpecials(s); + } else if (lastChar == '+') { + // Occurs only as "+j" + return 1.0; + } else if (lastChar == '-') { + // Occurs only as "-j" + return -1.0; + } else { + // Possibly a sign then an unsigned float + return Double.parseDouble(s); + } + } + } + private BigInteger asciiToBigInteger(int base, boolean isLong) { String str = getString(); -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:07 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:07 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Fixes_comparison_of_complex?= =?utf-8?q?_with_other_types=2E?= Message-ID: <20141231014105.22198.38987@psf.io> https://hg.python.org/jython/rev/e5bd019ce004 changeset: 7485:e5bd019ce004 user: Jeff Allen date: Sun Dec 21 17:03:12 2014 +0000 summary: Fixes comparison of complex with other types. These were failures in test_complex. Here endeth our need for a Jython-specific version of that. 
files: Lib/test/test_complex.py | 654 ----------------- src/org/python/core/PyComplex.java | 81 +- 2 files changed, 61 insertions(+), 674 deletions(-) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py deleted file mode 100644 --- a/Lib/test/test_complex.py +++ /dev/null @@ -1,654 +0,0 @@ -import unittest -from test import test_support - -from random import random -from math import atan2, isnan, copysign - -INF = float("inf") -NAN = float("nan") -# These tests ensure that complex math does the right thing - -class ComplexTest(unittest.TestCase): - - def assertAlmostEqual(self, a, b): - if isinstance(a, complex): - if isinstance(b, complex): - unittest.TestCase.assertAlmostEqual(self, a.real, b.real) - unittest.TestCase.assertAlmostEqual(self, a.imag, b.imag) - else: - unittest.TestCase.assertAlmostEqual(self, a.real, b) - unittest.TestCase.assertAlmostEqual(self, a.imag, 0.) - else: - if isinstance(b, complex): - unittest.TestCase.assertAlmostEqual(self, a, b.real) - unittest.TestCase.assertAlmostEqual(self, 0., b.imag) - else: - unittest.TestCase.assertAlmostEqual(self, a, b) - - def assertCloseAbs(self, x, y, eps=1e-9): - """Return true iff floats x and y "are close\"""" - # put the one with larger magnitude second - if abs(x) > abs(y): - x, y = y, x - if y == 0: - return abs(x) < eps - if x == 0: - return abs(y) < eps - # check that relative difference < eps - self.assertTrue(abs((x-y)/y) < eps) - - def assertFloatsAreIdentical(self, x, y): - """assert that floats x and y are identical, in the sense that: - (1) both x and y are nans, or - (2) both x and y are infinities, with the same sign, or - (3) both x and y are zeros, with the same sign, or - (4) x and y are both finite and nonzero, and x == y - - """ - msg = 'floats {!r} and {!r} are not identical' - - if isnan(x) or isnan(y): - if isnan(x) and isnan(y): - return - elif x == y: - if x != 0.0: - return - # both zero; check that signs match - elif copysign(1.0, x) == copysign(1.0, y): - return 
- else: - msg += ': zeros have different signs' - self.fail(msg.format(x, y)) - - def assertClose(self, x, y, eps=1e-9): - """Return true iff complexes x and y "are close\"""" - self.assertCloseAbs(x.real, y.real, eps) - self.assertCloseAbs(x.imag, y.imag, eps) - - def check_div(self, x, y): - """Compute complex z=x*y, and check that z/x==y and z/y==x.""" - z = x * y - if x != 0: - q = z / x - self.assertClose(q, y) - q = z.__div__(x) - self.assertClose(q, y) - q = z.__truediv__(x) - self.assertClose(q, y) - if y != 0: - q = z / y - self.assertClose(q, x) - q = z.__div__(y) - self.assertClose(q, x) - q = z.__truediv__(y) - self.assertClose(q, x) - - def test_div(self): - simple_real = [float(i) for i in xrange(-5, 6)] - simple_complex = [complex(x, y) for x in simple_real for y in simple_real] - for x in simple_complex: - for y in simple_complex: - self.check_div(x, y) - - # A naive complex division algorithm (such as in 2.0) is very prone to - # nonsense errors for these (overflows and underflows). - self.check_div(complex(1e200, 1e200), 1+0j) - self.check_div(complex(1e-200, 1e-200), 1+0j) - - # Just for fun. 
- for i in xrange(100): - self.check_div(complex(random(), random()), - complex(random(), random())) - - self.assertRaises(ZeroDivisionError, complex.__div__, 1+1j, 0+0j) - # FIXME: The following currently crashes on Alpha - # self.assertRaises(OverflowError, pow, 1e200+1j, 1e200+1j) - - def test_truediv(self): - self.assertAlmostEqual(complex.__truediv__(2+0j, 1+1j), 1-1j) - self.assertRaises(ZeroDivisionError, complex.__truediv__, 1+1j, 0+0j) - - def test_floordiv(self): - self.assertAlmostEqual(complex.__floordiv__(3+0j, 1.5+0j), 2) - self.assertRaises(ZeroDivisionError, complex.__floordiv__, 3+0j, 0+0j) - - def test_coerce(self): - self.assertRaises(OverflowError, complex.__coerce__, 1+1j, 1L<<10000) - - def test_no_implicit_coerce(self): - # Python 2.7 removed implicit coercion from the complex type - class A(object): - def __coerce__(self, other): - raise RuntimeError - __hash__ = None - def __cmp__(self, other): - return -1 - - a = A() - self.assertRaises(TypeError, lambda: a + 2.0j) - self.assertTrue(a < 2.0j) - - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") - def test_richcompare(self): - self.assertEqual(complex.__eq__(1+1j, 1L<<10000), False) - self.assertEqual(complex.__lt__(1+1j, None), NotImplemented) - self.assertIs(complex.__eq__(1+1j, 1+1j), True) - self.assertIs(complex.__eq__(1+1j, 2+2j), False) - self.assertIs(complex.__ne__(1+1j, 1+1j), False) - self.assertIs(complex.__ne__(1+1j, 2+2j), True) - self.assertRaises(TypeError, complex.__lt__, 1+1j, 2+2j) - self.assertRaises(TypeError, complex.__le__, 1+1j, 2+2j) - self.assertRaises(TypeError, complex.__gt__, 1+1j, 2+2j) - self.assertRaises(TypeError, complex.__ge__, 1+1j, 2+2j) - - @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython") - def test_richcompare_boundaries(self): - def check(n, deltas, is_equal, imag = 0.0): - for delta in deltas: - i = n + delta - z = complex(i, imag) - self.assertIs(complex.__eq__(z, i), is_equal(delta)) - 
self.assertIs(complex.__ne__(z, i), not is_equal(delta)) - # For IEEE-754 doubles the following should hold: - # x in [2 ** (52 + i), 2 ** (53 + i + 1)] -> x mod 2 ** i == 0 - # where the interval is representable, of course. - for i in range(1, 10): - pow = 52 + i - mult = 2 ** i - check(2 ** pow, range(1, 101), lambda delta: delta % mult == 0) - check(2 ** pow, range(1, 101), lambda delta: False, float(i)) - check(2 ** 53, range(-100, 0), lambda delta: True) - - def test_mod(self): - self.assertRaises(ZeroDivisionError, (1+1j).__mod__, 0+0j) - - a = 3.33+4.43j - try: - a % 0 - except ZeroDivisionError: - pass - else: - self.fail("modulo parama can't be 0") - - def test_divmod(self): - self.assertRaises(ZeroDivisionError, divmod, 1+1j, 0+0j) - - def test_pow(self): - self.assertAlmostEqual(pow(1+1j, 0+0j), 1.0) - self.assertAlmostEqual(pow(0+0j, 2+0j), 0.0) - self.assertRaises(ZeroDivisionError, pow, 0+0j, 1j) - self.assertAlmostEqual(pow(1j, -1), 1/1j) - self.assertAlmostEqual(pow(1j, 200), 1) - self.assertRaises(ValueError, pow, 1+1j, 1+1j, 1+1j) - - a = 3.33+4.43j - self.assertEqual(a ** 0j, 1) - self.assertEqual(a ** 0.+0.j, 1) - - self.assertEqual(3j ** 0j, 1) - self.assertEqual(3j ** 0, 1) - - try: - 0j ** a - except ZeroDivisionError: - pass - else: - self.fail("should fail 0.0 to negative or complex power") - - try: - 0j ** (3-2j) - except ZeroDivisionError: - pass - else: - self.fail("should fail 0.0 to negative or complex power") - - # The following is used to exercise certain code paths - self.assertEqual(a ** 105, a ** 105) - self.assertEqual(a ** -105, a ** -105) - self.assertEqual(a ** -30, a ** -30) - - self.assertEqual(0.0j ** 0, 1) - - b = 5.1+2.3j - self.assertRaises(ValueError, pow, a, b, 0) - - def test_boolcontext(self): - for i in xrange(100): - self.assertTrue(complex(random() + 1e-6, random() + 1e-6)) - self.assertTrue(not complex(0.0, 0.0)) - - def test_conjugate(self): - self.assertClose(complex(5.3, 9.8).conjugate(), 5.3-9.8j) - - def 
test_constructor(self): - class OS: - def __init__(self, value): self.value = value - def __complex__(self): return self.value - class NS(object): - def __init__(self, value): self.value = value - def __complex__(self): return self.value - self.assertEqual(complex(OS(1+10j)), 1+10j) - self.assertEqual(complex(NS(1+10j)), 1+10j) - self.assertRaises(TypeError, complex, OS(None)) - self.assertRaises(TypeError, complex, NS(None)) - - self.assertAlmostEqual(complex("1+10j"), 1+10j) - self.assertAlmostEqual(complex(10), 10+0j) - self.assertAlmostEqual(complex(10.0), 10+0j) - self.assertAlmostEqual(complex(10L), 10+0j) - self.assertAlmostEqual(complex(10+0j), 10+0j) - self.assertAlmostEqual(complex(1,10), 1+10j) - self.assertAlmostEqual(complex(1,10L), 1+10j) - self.assertAlmostEqual(complex(1,10.0), 1+10j) - self.assertAlmostEqual(complex(1L,10), 1+10j) - self.assertAlmostEqual(complex(1L,10L), 1+10j) - self.assertAlmostEqual(complex(1L,10.0), 1+10j) - self.assertAlmostEqual(complex(1.0,10), 1+10j) - self.assertAlmostEqual(complex(1.0,10L), 1+10j) - self.assertAlmostEqual(complex(1.0,10.0), 1+10j) - self.assertAlmostEqual(complex(3.14+0j), 3.14+0j) - self.assertAlmostEqual(complex(3.14), 3.14+0j) - self.assertAlmostEqual(complex(314), 314.0+0j) - self.assertAlmostEqual(complex(314L), 314.0+0j) - self.assertAlmostEqual(complex(3.14+0j, 0j), 3.14+0j) - self.assertAlmostEqual(complex(3.14, 0.0), 3.14+0j) - self.assertAlmostEqual(complex(314, 0), 314.0+0j) - self.assertAlmostEqual(complex(314L, 0L), 314.0+0j) - self.assertAlmostEqual(complex(0j, 3.14j), -3.14+0j) - self.assertAlmostEqual(complex(0.0, 3.14j), -3.14+0j) - self.assertAlmostEqual(complex(0j, 3.14), 3.14j) - self.assertAlmostEqual(complex(0.0, 3.14), 3.14j) - self.assertAlmostEqual(complex("1"), 1+0j) - self.assertAlmostEqual(complex("1j"), 1j) - self.assertAlmostEqual(complex(), 0) - self.assertAlmostEqual(complex("-1"), -1) - self.assertAlmostEqual(complex("+1"), +1) - self.assertAlmostEqual(complex("(1+2j)"), 
1+2j) - self.assertAlmostEqual(complex("(1.3+2.2j)"), 1.3+2.2j) - self.assertAlmostEqual(complex("3.14+1J"), 3.14+1j) - self.assertAlmostEqual(complex(" ( +3.14-6J )"), 3.14-6j) - self.assertAlmostEqual(complex(" ( +3.14-J )"), 3.14-1j) - self.assertAlmostEqual(complex(" ( +3.14+j )"), 3.14+1j) - self.assertAlmostEqual(complex("J"), 1j) - self.assertAlmostEqual(complex("( j )"), 1j) - self.assertAlmostEqual(complex("+J"), 1j) - self.assertAlmostEqual(complex("( -j)"), -1j) - self.assertAlmostEqual(complex('1e-500'), 0.0 + 0.0j) - self.assertAlmostEqual(complex('-1e-500j'), 0.0 - 0.0j) - self.assertAlmostEqual(complex('-1e-500+1e-500j'), -0.0 + 0.0j) - - class complex2(complex): pass - self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j) - self.assertAlmostEqual(complex(real=17, imag=23), 17+23j) - self.assertAlmostEqual(complex(real=17+23j), 17+23j) - self.assertAlmostEqual(complex(real=17+23j, imag=23), 17+46j) - self.assertAlmostEqual(complex(real=1+2j, imag=3+4j), -3+5j) - - # check that the sign of a zero in the real or imaginary part - # is preserved when constructing from two floats. (These checks - # are harmless on systems without support for signed zeros.) - def split_zeros(x): - """Function that produces different results for 0. and -0.""" - return atan2(x, -1.) - - self.assertEqual(split_zeros(complex(1., 0.).imag), split_zeros(0.)) - self.assertEqual(split_zeros(complex(1., -0.).imag), split_zeros(-0.)) - self.assertEqual(split_zeros(complex(0., 1.).real), split_zeros(0.)) - self.assertEqual(split_zeros(complex(-0., 1.).real), split_zeros(-0.)) - - c = 3.14 + 1j - self.assertTrue(complex(c) is c) - del c - - self.assertRaises(TypeError, complex, "1", "1") - self.assertRaises(TypeError, complex, 1, "1") - - if test_support.have_unicode: - self.assertEqual(complex(unicode(" 3.14+J ")), 3.14+1j) - - # SF bug 543840: complex(string) accepts strings with \0 - # Fixed in 2.3. 
- self.assertRaises(ValueError, complex, '1+1j\0j') - - self.assertRaises(TypeError, int, 5+3j) - self.assertRaises(TypeError, long, 5+3j) - self.assertRaises(TypeError, float, 5+3j) - self.assertRaises(ValueError, complex, "") - self.assertRaises(TypeError, complex, None) - self.assertRaises(ValueError, complex, "\0") - self.assertRaises(ValueError, complex, "3\09") - self.assertRaises(TypeError, complex, "1", "2") - self.assertRaises(TypeError, complex, "1", 42) - self.assertRaises(TypeError, complex, 1, "2") - self.assertRaises(ValueError, complex, "1+") - self.assertRaises(ValueError, complex, "1+1j+1j") - self.assertRaises(ValueError, complex, "--") - self.assertRaises(ValueError, complex, "(1+2j") - self.assertRaises(ValueError, complex, "1+2j)") - self.assertRaises(ValueError, complex, "1+(2j)") - self.assertRaises(ValueError, complex, "(1+2j)123") - if test_support.have_unicode: - self.assertRaises(ValueError, complex, unicode("x")) - self.assertRaises(ValueError, complex, "1j+2") - self.assertRaises(ValueError, complex, "1e1ej") - self.assertRaises(ValueError, complex, "1e++1ej") - self.assertRaises(ValueError, complex, ")1+2j(") - - # the following three are accepted by Python 2.6 - self.assertRaises(ValueError, complex, "1..1j") - self.assertRaises(ValueError, complex, "1.11.1j") - self.assertRaises(ValueError, complex, "1e1.1j") - - if test_support.have_unicode: - # check that complex accepts long unicode strings - self.assertEqual(type(complex(unicode("1"*500))), complex) - - class EvilExc(Exception): - pass - - class evilcomplex: - def __complex__(self): - raise EvilExc - - self.assertRaises(EvilExc, complex, evilcomplex()) - - class float2: - def __init__(self, value): - self.value = value - def __float__(self): - return self.value - - self.assertAlmostEqual(complex(float2(42.)), 42) - self.assertAlmostEqual(complex(real=float2(17.), imag=float2(23.)), 17+23j) - self.assertRaises(TypeError, complex, float2(None)) - - class complex0(complex): - 
"""Test usage of __complex__() when inheriting from 'complex'""" - def __complex__(self): - return 42j - - class complex1(complex): - """Test usage of __complex__() with a __new__() method""" - def __new__(self, value=0j): - return complex.__new__(self, 2*value) - def __complex__(self): - return self - - class complex2(complex): - """Make sure that __complex__() calls fail if anything other than a - complex is returned""" - def __complex__(self): - return None - - self.assertAlmostEqual(complex(complex0(1j)), 42j) - self.assertAlmostEqual(complex(complex1(1j)), 2j) - self.assertRaises(TypeError, complex, complex2(1j)) - - def test_subclass(self): - class xcomplex(complex): - def __add__(self,other): - return xcomplex(complex(self) + other) - __radd__ = __add__ - - def __sub__(self,other): - return xcomplex(complex(self) + other) - __rsub__ = __sub__ - - def __mul__(self,other): - return xcomplex(complex(self) * other) - __rmul__ = __mul__ - - def __div__(self,other): - return xcomplex(complex(self) / other) - - def __rdiv__(self,other): - return xcomplex(other / complex(self)) - - __truediv__ = __div__ - __rtruediv__ = __rdiv__ - - def __floordiv__(self,other): - return xcomplex(complex(self) // other) - - def __rfloordiv__(self,other): - return xcomplex(other // complex(self)) - - def __pow__(self,other): - return xcomplex(complex(self) ** other) - - def __rpow__(self,other): - return xcomplex(other ** complex(self) ) - - def __mod__(self,other): - return xcomplex(complex(self) % other) - - def __rmod__(self,other): - return xcomplex(other % complex(self)) - - infix_binops = ('+', '-', '*', '**', '%', '//', '/') - xcomplex_values = (xcomplex(1), xcomplex(123.0), - xcomplex(-10+2j), xcomplex(3+187j), - xcomplex(3-78j)) - test_values = (1, 123.0, 10-19j, xcomplex(1+2j), - xcomplex(1+87j), xcomplex(10+90j)) - - for op in infix_binops: - for x in xcomplex_values: - for y in test_values: - a = 'x %s y' % op - b = 'y %s x' % op - self.assertTrue(type(eval(a)) is 
type(eval(b)) is xcomplex) - - def test_hash(self): - for x in xrange(-30, 30): - self.assertEqual(hash(x), hash(complex(x, 0))) - x /= 3.0 # now check against floating point - self.assertEqual(hash(x), hash(complex(x, 0.))) - - def test_abs(self): - nums = [complex(x/3., y/7.) for x in xrange(-9,9) for y in xrange(-9,9)] - for num in nums: - self.assertAlmostEqual((num.real**2 + num.imag**2) ** 0.5, abs(num)) - - def test_repr(self): - self.assertEqual(repr(1+6j), '(1+6j)') - self.assertEqual(repr(1-6j), '(1-6j)') - - self.assertNotEqual(repr(-(1+0j)), '(-1+-0j)') - - self.assertEqual(1-6j,complex(repr(1-6j))) - self.assertEqual(1+6j,complex(repr(1+6j))) - self.assertEqual(-6j,complex(repr(-6j))) - self.assertEqual(6j,complex(repr(6j))) - - self.assertEqual(repr(complex(1., INF)), "(1+infj)") - self.assertEqual(repr(complex(1., -INF)), "(1-infj)") - self.assertEqual(repr(complex(INF, 1)), "(inf+1j)") - self.assertEqual(repr(complex(-INF, INF)), "(-inf+infj)") - self.assertEqual(repr(complex(NAN, 1)), "(nan+1j)") - self.assertEqual(repr(complex(1, NAN)), "(1+nanj)") - self.assertEqual(repr(complex(NAN, NAN)), "(nan+nanj)") - - self.assertEqual(repr(complex(0, INF)), "infj") - self.assertEqual(repr(complex(0, -INF)), "-infj") - self.assertEqual(repr(complex(0, NAN)), "nanj") - - def test_neg(self): - self.assertEqual(-(1+6j), -1-6j) - - def test_file(self): - a = 3.33+4.43j - b = 5.1+2.3j - - fo = None - try: - fo = open(test_support.TESTFN, "wb") - print >>fo, a, b - fo.close() - fo = open(test_support.TESTFN, "rb") - self.assertEqual(fo.read(), "%s %s\n" % (a, b)) - finally: - if (fo is not None) and (not fo.closed): - fo.close() - test_support.unlink(test_support.TESTFN) - - def test_getnewargs(self): - self.assertEqual((1+2j).__getnewargs__(), (1.0, 2.0)) - self.assertEqual((1-2j).__getnewargs__(), (1.0, -2.0)) - self.assertEqual((2j).__getnewargs__(), (0.0, 2.0)) - self.assertEqual((-0j).__getnewargs__(), (0.0, -0.0)) - self.assertEqual(complex(0, 
INF).__getnewargs__(), (0.0, INF)) - self.assertEqual(complex(INF, 0).__getnewargs__(), (INF, 0.0)) - - if float.__getformat__("double").startswith("IEEE"): - def test_plus_minus_0j(self): - # test that -0j and 0j literals are not identified - z1, z2 = 0j, -0j - self.assertEqual(atan2(z1.imag, -1.), atan2(0., -1.)) - self.assertEqual(atan2(z2.imag, -1.), atan2(-0., -1.)) - - @unittest.skipUnless(float.__getformat__("double").startswith("IEEE"), - "test requires IEEE 754 doubles") - def test_overflow(self): - self.assertEqual(complex("1e500"), complex(INF, 0.0)) - self.assertEqual(complex("-1e500j"), complex(0.0, -INF)) - self.assertEqual(complex("-1e500+1.8e308j"), complex(-INF, INF)) - - @unittest.skipUnless(float.__getformat__("double").startswith("IEEE"), - "test requires IEEE 754 doubles") - def test_repr_roundtrip(self): - vals = [0.0, 1e-500, 1e-315, 1e-200, 0.0123, 3.1415, 1e50, INF, NAN] - vals += [-v for v in vals] - - # complex(repr(z)) should recover z exactly, even for complex - # numbers involving an infinity, nan, or negative zero - for x in vals: - for y in vals: - z = complex(x, y) - roundtrip = complex(repr(z)) - self.assertFloatsAreIdentical(z.real, roundtrip.real) - self.assertFloatsAreIdentical(z.imag, roundtrip.imag) - - # if we predefine some constants, then eval(repr(z)) should - # also work, except that it might change the sign of zeros - inf, nan = float('inf'), float('nan') - infj, nanj = complex(0.0, inf), complex(0.0, nan) - for x in vals: - for y in vals: - z = complex(x, y) - roundtrip = eval(repr(z)) - # adding 0.0 has no effect beside changing -0.0 to 0.0 - self.assertFloatsAreIdentical(0.0 + z.real, - 0.0 + roundtrip.real) - self.assertFloatsAreIdentical(0.0 + z.imag, - 0.0 + roundtrip.imag) - - def test_format(self): - # empty format string is same as str() - self.assertEqual(format(1+3j, ''), str(1+3j)) - self.assertEqual(format(1.5+3.5j, ''), str(1.5+3.5j)) - self.assertEqual(format(3j, ''), str(3j)) - 
self.assertEqual(format(3.2j, ''), str(3.2j)) - self.assertEqual(format(3+0j, ''), str(3+0j)) - self.assertEqual(format(3.2+0j, ''), str(3.2+0j)) - - # empty presentation type should still be analogous to str, - # even when format string is nonempty (issue #5920). - self.assertEqual(format(3.2+0j, '-'), str(3.2+0j)) - self.assertEqual(format(3.2+0j, '<'), str(3.2+0j)) - z = 4/7. - 100j/7. - self.assertEqual(format(z, ''), str(z)) - self.assertEqual(format(z, '-'), str(z)) - self.assertEqual(format(z, '<'), str(z)) - self.assertEqual(format(z, '10'), str(z)) - z = complex(0.0, 3.0) - self.assertEqual(format(z, ''), str(z)) - self.assertEqual(format(z, '-'), str(z)) - self.assertEqual(format(z, '<'), str(z)) - self.assertEqual(format(z, '2'), str(z)) - z = complex(-0.0, 2.0) - self.assertEqual(format(z, ''), str(z)) - self.assertEqual(format(z, '-'), str(z)) - self.assertEqual(format(z, '<'), str(z)) - self.assertEqual(format(z, '3'), str(z)) - - self.assertEqual(format(1+3j, 'g'), '1+3j') - self.assertEqual(format(3j, 'g'), '0+3j') - self.assertEqual(format(1.5+3.5j, 'g'), '1.5+3.5j') - - self.assertEqual(format(1.5+3.5j, '+g'), '+1.5+3.5j') - self.assertEqual(format(1.5-3.5j, '+g'), '+1.5-3.5j') - self.assertEqual(format(1.5-3.5j, '-g'), '1.5-3.5j') - self.assertEqual(format(1.5+3.5j, ' g'), ' 1.5+3.5j') - self.assertEqual(format(1.5-3.5j, ' g'), ' 1.5-3.5j') - self.assertEqual(format(-1.5+3.5j, ' g'), '-1.5+3.5j') - self.assertEqual(format(-1.5-3.5j, ' g'), '-1.5-3.5j') - - self.assertEqual(format(-1.5-3.5e-20j, 'g'), '-1.5-3.5e-20j') - self.assertEqual(format(-1.5-3.5j, 'f'), '-1.500000-3.500000j') - self.assertEqual(format(-1.5-3.5j, 'F'), '-1.500000-3.500000j') - self.assertEqual(format(-1.5-3.5j, 'e'), '-1.500000e+00-3.500000e+00j') - self.assertEqual(format(-1.5-3.5j, '.2e'), '-1.50e+00-3.50e+00j') - self.assertEqual(format(-1.5-3.5j, '.2E'), '-1.50E+00-3.50E+00j') - self.assertEqual(format(-1.5e10-3.5e5j, '.2G'), '-1.5E+10-3.5E+05j') - - 
self.assertEqual(format(1.5+3j, '<20g'), '1.5+3j ') - self.assertEqual(format(1.5+3j, '*<20g'), '1.5+3j**************') - self.assertEqual(format(1.5+3j, '>20g'), ' 1.5+3j') - self.assertEqual(format(1.5+3j, '^20g'), ' 1.5+3j ') - self.assertEqual(format(1.5+3j, '<20'), '(1.5+3j) ') - self.assertEqual(format(1.5+3j, '>20'), ' (1.5+3j)') - self.assertEqual(format(1.5+3j, '^20'), ' (1.5+3j) ') - self.assertEqual(format(1.123-3.123j, '^20.2'), ' (1.1-3.1j) ') - - self.assertEqual(format(1.5+3j, '20.2f'), ' 1.50+3.00j') - self.assertEqual(format(1.5+3j, '>20.2f'), ' 1.50+3.00j') - self.assertEqual(format(1.5+3j, '<20.2f'), '1.50+3.00j ') - self.assertEqual(format(1.5e20+3j, '<20.2f'), '150000000000000000000.00+3.00j') - self.assertEqual(format(1.5e20+3j, '>40.2f'), ' 150000000000000000000.00+3.00j') - self.assertEqual(format(1.5e20+3j, '^40,.2f'), ' 150,000,000,000,000,000,000.00+3.00j ') - self.assertEqual(format(1.5e21+3j, '^40,.2f'), ' 1,500,000,000,000,000,000,000.00+3.00j ') - self.assertEqual(format(1.5e21+3000j, ',.2f'), '1,500,000,000,000,000,000,000.00+3,000.00j') - - # alternate is invalid - self.assertRaises(ValueError, (1.5+0.5j).__format__, '#f') - - # zero padding is invalid - self.assertRaises(ValueError, (1.5+0.5j).__format__, '010f') - - # '=' alignment is invalid - self.assertRaises(ValueError, (1.5+3j).__format__, '=20') - - # integer presentation types are an error - for t in 'bcdoxX': - self.assertRaises(ValueError, (1.5+0.5j).__format__, t) - - # make sure everything works in ''.format() - self.assertEqual('*{0:.3f}*'.format(3.14159+2.71828j), '*3.142+2.718j*') - - # issue 3382: 'f' and 'F' with inf's and nan's - self.assertEqual('{0:f}'.format(INF+0j), 'inf+0.000000j') - self.assertEqual('{0:F}'.format(INF+0j), 'INF+0.000000j') - self.assertEqual('{0:f}'.format(-INF+0j), '-inf+0.000000j') - self.assertEqual('{0:F}'.format(-INF+0j), '-INF+0.000000j') - self.assertEqual('{0:f}'.format(complex(INF, INF)), 'inf+infj') - 
self.assertEqual('{0:F}'.format(complex(INF, INF)), 'INF+INFj') - self.assertEqual('{0:f}'.format(complex(INF, -INF)), 'inf-infj') - self.assertEqual('{0:F}'.format(complex(INF, -INF)), 'INF-INFj') - self.assertEqual('{0:f}'.format(complex(-INF, INF)), '-inf+infj') - self.assertEqual('{0:F}'.format(complex(-INF, INF)), '-INF+INFj') - self.assertEqual('{0:f}'.format(complex(-INF, -INF)), '-inf-infj') - self.assertEqual('{0:F}'.format(complex(-INF, -INF)), '-INF-INFj') - - self.assertEqual('{0:f}'.format(complex(NAN, 0)), 'nan+0.000000j') - self.assertEqual('{0:F}'.format(complex(NAN, 0)), 'NAN+0.000000j') - self.assertEqual('{0:f}'.format(complex(NAN, NAN)), 'nan+nanj') - self.assertEqual('{0:F}'.format(complex(NAN, NAN)), 'NAN+NANj') - -def test_main(): - with test_support.check_warnings(("complex divmod.., // and % are " - "deprecated", DeprecationWarning)): - test_support.run_unittest(ComplexTest) - -if __name__ == "__main__": - test_main() diff --git a/src/org/python/core/PyComplex.java b/src/org/python/core/PyComplex.java --- a/src/org/python/core/PyComplex.java +++ b/src/org/python/core/PyComplex.java @@ -27,7 +27,8 @@ static PyComplex J = new PyComplex(0, 1.); - public static final PyComplex Inf = new PyComplex(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY); + public static final PyComplex Inf = new PyComplex(Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY); public static final PyComplex NaN = new PyComplex(Double.NaN, Double.NaN); @ExposedGet(doc = BuiltinDocs.complex_real_doc) @@ -247,11 +248,53 @@ @ExposedMethod(type = MethodType.BINARY, doc = BuiltinDocs.complex___eq___doc) final PyObject complex___eq__(PyObject other) { - if (!canCoerce(other)) { - return null; + switch (eq_helper(other)) { + case 0: + return Py.False; + case 1: + return Py.True; + default: + return null; } - PyComplex c = coerce(other); - return Py.newBoolean(real == c.real && imag == c.imag); + } + + /** + * Helper for {@link #complex___eq__(PyObject)} and {@link 
#complex___ne__(PyObject)}. + * + * @param other to compare for equality with this + * @return 0 = false, 1 = true, 2 = don't know (ask the other object) + */ + private int eq_helper(PyObject other) { + // We only deal with primitive types here. All others delegate upwards (return 2). + boolean equal; + if (other instanceof PyComplex) { + PyComplex c = ((PyComplex)other); + equal = (this.real == c.real && this.imag == c.imag); + } else if (other instanceof PyFloat) { + PyFloat f = ((PyFloat)other); + equal = (this.imag == 0.0 && this.real == f.getValue()); + } else if (other instanceof PyInteger || other instanceof PyLong) { + if (this.imag == 0.0) { + // The imaginary part is zero: other object primitive might equal the real part. + double r = this.real; + if (Double.isInfinite(r) || Double.isNaN(r)) { + // No integer primitive type can be infinite, and NaN never equals anything. + equal = false; + } else { + // Delegate the logic to PyFloat + PyFloat f = new PyFloat(r); + equal = (f.float___cmp__(other) == 0); + } + } else { + // No other primitive can have an imaginary part. + equal = false; + } + } else { + // other is not one of the types we know how to deal with. + return 2; + } + // Only "known" cases end here: translate to return code + return equal ? 
1 : 0; } @Override @@ -261,11 +304,14 @@ @ExposedMethod(type = MethodType.BINARY, doc = BuiltinDocs.complex___ne___doc) final PyObject complex___ne__(PyObject other) { - if (!canCoerce(other)) { - return null; + switch (eq_helper(other)) { + case 0: + return Py.True; + case 1: + return Py.False; + default: + return null; } - PyComplex c = coerce(other); - return Py.newBoolean(real != c.real || imag != c.imag); } private PyObject unsupported_comparison(PyObject other) { @@ -415,22 +461,17 @@ } private final static PyObject _mul(PyComplex o1, PyComplex o2) { - if (Double.isNaN(o1.real) || - Double.isNaN(o1.imag) || - Double.isNaN(o2.real) || - Double.isNaN(o2.imag)) { + if (Double.isNaN(o1.real) || Double.isNaN(o1.imag) || Double.isNaN(o2.real) + || Double.isNaN(o2.imag)) { return NaN; } - if (Double.isInfinite(o1.real) || - Double.isInfinite(o1.imag) || - Double.isInfinite(o2.real) || - Double.isInfinite(o2.imag)) { + if (Double.isInfinite(o1.real) || Double.isInfinite(o1.imag) || Double.isInfinite(o2.real) + || Double.isInfinite(o2.imag)) { return Inf; } - return new PyComplex( - o1.real * o2.real - o1.imag * o2.imag, - o1.real * o2.imag + o1.imag * o2.real); + return new PyComplex(o1.real * o2.real - o1.imag * o2.imag, o1.real * o2.imag + o1.imag + * o2.real); } @Override -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:07 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:07 +0000 Subject: [Jython-checkins] =?utf-8?b?anl0aG9uOiBGaXggZmxvYXQuX19wb3dfXyB0?= =?utf-8?q?o_conform_to_spec=2E?= Message-ID: <20141231014103.125155.55889@psf.io> https://hg.python.org/jython/rev/5ad7dc15449d changeset: 7480:5ad7dc15449d user: Jeff Allen date: Thu Dec 18 22:10:58 2014 +0000 summary: Fix float.__pow__ to conform to spec. Fixes test failures in test_float, which can now revert to the CPython version. 
files: Lib/test/test_float.py | 1411 ------------------ src/org/python/core/PyFloat.java | 69 +- 2 files changed, 43 insertions(+), 1437 deletions(-) diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py deleted file mode 100644 --- a/Lib/test/test_float.py +++ /dev/null @@ -1,1411 +0,0 @@ - -import unittest, struct -import os -from test import test_support -import math -from math import isinf, isnan, copysign, ldexp -import operator -import random -import fractions -import sys - -INF = float("inf") -NAN = float("nan") - -have_getformat = hasattr(float, "__getformat__") -requires_getformat = unittest.skipUnless(have_getformat, - "requires __getformat__") -requires_setformat = unittest.skipUnless(hasattr(float, "__setformat__"), - "requires __setformat__") -# decorator for skipping tests on non-IEEE 754 platforms -requires_IEEE_754 = unittest.skipUnless(have_getformat and - float.__getformat__("double").startswith("IEEE"), - "test requires IEEE 754 doubles") - -#locate file with float format test values -test_dir = os.path.dirname(__file__) or os.curdir -format_testfile = os.path.join(test_dir, 'formatfloat_testcases.txt') - -class GeneralFloatCases(unittest.TestCase): - - def test_float(self): - self.assertEqual(float(3.14), 3.14) - self.assertEqual(float(314), 314.0) - self.assertEqual(float(314L), 314.0) - self.assertEqual(float(" 3.14 "), 3.14) - self.assertRaises(ValueError, float, " 0x3.1 ") - - self.assertRaises(ValueError, float, " -0x3.p-1 ") - self.assertRaises(ValueError, float, " +0x3.p-1 ") - - self.assertRaises(ValueError, float, "++3.14") - self.assertRaises(ValueError, float, "+-3.14") - self.assertRaises(ValueError, float, "-+3.14") - self.assertRaises(ValueError, float, "--3.14") - # check that we don't accept alternate exponent markers - self.assertRaises(ValueError, float, "-1.7d29") - self.assertRaises(ValueError, float, "3D-14") - if test_support.have_unicode: - self.assertEqual(float(unicode(" 3.14 ")), 3.14) - 
self.assertEqual(float(unicode(" \u0663.\u0661\u0664 ",'raw-unicode-escape')), 3.14) - - # extra long strings should no longer be a problem - # (in 2.6, long unicode inputs to float raised ValueError) - float('.' + '1'*1000) - float(unicode('.' + '1'*1000)) - - def check_conversion_to_int(self, x): - """Check that int(x) has the correct value and type, for a float x.""" - n = int(x) - if x >= 0.0: - # x >= 0 and n = int(x) ==> n <= x < n + 1 - self.assertLessEqual(n, x) - self.assertLess(x, n + 1) - else: - # x < 0 and n = int(x) ==> n >= x > n - 1 - self.assertGreaterEqual(n, x) - self.assertGreater(x, n - 1) - - # Result should be an int if within range, else a long. - if -sys.maxint-1 <= n <= sys.maxint: - self.assertEqual(type(n), int) - else: - self.assertEqual(type(n), long) - - # Double check. - self.assertEqual(type(int(n)), type(n)) - - def test_conversion_to_int(self): - # Check that floats within the range of an int convert to type - # int, not long. (issue #11144.) - boundary = float(sys.maxint + 1) - epsilon = 2**-sys.float_info.mant_dig * boundary - - # These 2 floats are either side of the positive int/long boundary on - # both 32-bit and 64-bit systems. - self.check_conversion_to_int(boundary - epsilon) - self.check_conversion_to_int(boundary) - - # These floats are either side of the negative long/int boundary on - # 64-bit systems... - self.check_conversion_to_int(-boundary - 2*epsilon) - self.check_conversion_to_int(-boundary) - - # ... and these ones are either side of the negative long/int - # boundary on 32-bit systems. - self.check_conversion_to_int(-boundary - 1.0) - self.check_conversion_to_int(-boundary - 1.0 + 2*epsilon) - - @test_support.run_with_locale('LC_NUMERIC', 'fr_FR', 'de_DE') - def test_float_with_comma(self): - # set locale to something that doesn't use '.' 
for the decimal point - # float must not accept the locale specific decimal point but - # it still has to accept the normal python syntax - import locale - if not locale.localeconv()['decimal_point'] == ',': - return - - self.assertEqual(float(" 3.14 "), 3.14) - self.assertEqual(float("+3.14 "), 3.14) - self.assertEqual(float("-3.14 "), -3.14) - self.assertEqual(float(".14 "), .14) - self.assertEqual(float("3. "), 3.0) - self.assertEqual(float("3.e3 "), 3000.0) - self.assertEqual(float("3.2e3 "), 3200.0) - self.assertEqual(float("2.5e-1 "), 0.25) - self.assertEqual(float("5e-1"), 0.5) - self.assertRaises(ValueError, float, " 3,14 ") - self.assertRaises(ValueError, float, " +3,14 ") - self.assertRaises(ValueError, float, " -3,14 ") - self.assertRaises(ValueError, float, " 0x3.1 ") - self.assertRaises(ValueError, float, " -0x3.p-1 ") - self.assertRaises(ValueError, float, " +0x3.p-1 ") - self.assertEqual(float(" 25.e-1 "), 2.5) - self.assertEqual(test_support.fcmp(float(" .25e-1 "), .025), 0) - - def test_floatconversion(self): - # Make sure that calls to __float__() work properly - class Foo0: - def __float__(self): - return 42. - - class Foo1(object): - def __float__(self): - return 42. - - class Foo2(float): - def __float__(self): - return 42. - - class Foo3(float): - def __new__(cls, value=0.): - return float.__new__(cls, 2*value) - - def __float__(self): - return self - - class Foo4(float): - def __float__(self): - return 42 - - # Issue 5759: __float__ not called on str subclasses (though it is on - # unicode subclasses). - class FooStr(str): - def __float__(self): - return float(str(self)) + 1 - - class FooUnicode(unicode): - def __float__(self): - return float(unicode(self)) + 1 - - self.assertAlmostEqual(float(Foo0()), 42.) - self.assertAlmostEqual(float(Foo1()), 42.) - self.assertAlmostEqual(float(Foo2()), 42.) - self.assertAlmostEqual(float(Foo3(21)), 42.) - self.assertRaises(TypeError, float, Foo4(42)) - self.assertAlmostEqual(float(FooUnicode('8')), 9.) 
- self.assertAlmostEqual(float(FooStr('8')), 9.) - - def test_is_integer(self): - self.assertFalse((1.1).is_integer()) - self.assertTrue((1.).is_integer()) - self.assertFalse(float("nan").is_integer()) - self.assertFalse(float("inf").is_integer()) - - def test_floatasratio(self): - for f, ratio in [ - (0.875, (7, 8)), - (-0.875, (-7, 8)), - (0.0, (0, 1)), - (11.5, (23, 2)), - ]: - self.assertEqual(f.as_integer_ratio(), ratio) - - for i in range(10000): - f = random.random() - f *= 10 ** random.randint(-100, 100) - n, d = f.as_integer_ratio() - self.assertEqual(float(n).__truediv__(d), f) - - R = fractions.Fraction - self.assertEqual(R(0, 1), - R(*float(0.0).as_integer_ratio())) - self.assertEqual(R(5, 2), - R(*float(2.5).as_integer_ratio())) - self.assertEqual(R(1, 2), - R(*float(0.5).as_integer_ratio())) - self.assertEqual(R(4728779608739021, 2251799813685248), - R(*float(2.1).as_integer_ratio())) - self.assertEqual(R(-4728779608739021, 2251799813685248), - R(*float(-2.1).as_integer_ratio())) - self.assertEqual(R(-2100, 1), - R(*float(-2100.0).as_integer_ratio())) - - self.assertRaises(OverflowError, float('inf').as_integer_ratio) - self.assertRaises(OverflowError, float('-inf').as_integer_ratio) - self.assertRaises(ValueError, float('nan').as_integer_ratio) - - def assertEqualAndEqualSign(self, a, b): - # fail unless a == b and a and b have the same sign bit; - # the only difference from assertEqual is that this test - # distinguishes -0.0 and 0.0. - self.assertEqual((a, copysign(1.0, a)), (b, copysign(1.0, b))) - - @requires_IEEE_754 - def test_float_mod(self): - # Check behaviour of % operator for IEEE 754 special cases. - # In particular, check signs of zeros. 
- mod = operator.mod - - self.assertEqualAndEqualSign(mod(-1.0, 1.0), 0.0) - self.assertEqualAndEqualSign(mod(-1e-100, 1.0), 1.0) - self.assertEqualAndEqualSign(mod(-0.0, 1.0), 0.0) - self.assertEqualAndEqualSign(mod(0.0, 1.0), 0.0) - self.assertEqualAndEqualSign(mod(1e-100, 1.0), 1e-100) - self.assertEqualAndEqualSign(mod(1.0, 1.0), 0.0) - - self.assertEqualAndEqualSign(mod(-1.0, -1.0), -0.0) - self.assertEqualAndEqualSign(mod(-1e-100, -1.0), -1e-100) - self.assertEqualAndEqualSign(mod(-0.0, -1.0), -0.0) - self.assertEqualAndEqualSign(mod(0.0, -1.0), -0.0) - self.assertEqualAndEqualSign(mod(1e-100, -1.0), -1.0) - self.assertEqualAndEqualSign(mod(1.0, -1.0), -0.0) - - @requires_IEEE_754 - def test_float_pow(self): - # test builtin pow and ** operator for IEEE 754 special cases. - # Special cases taken from section F.9.4.4 of the C99 specification - - for pow_op in pow, operator.pow: - # x**NAN is NAN for any x except 1 - self.assertTrue(isnan(pow_op(-INF, NAN))) - self.assertTrue(isnan(pow_op(-2.0, NAN))) - self.assertTrue(isnan(pow_op(-1.0, NAN))) - self.assertTrue(isnan(pow_op(-0.5, NAN))) - self.assertTrue(isnan(pow_op(-0.0, NAN))) - self.assertTrue(isnan(pow_op(0.0, NAN))) - self.assertTrue(isnan(pow_op(0.5, NAN))) - self.assertTrue(isnan(pow_op(2.0, NAN))) - self.assertTrue(isnan(pow_op(INF, NAN))) - self.assertTrue(isnan(pow_op(NAN, NAN))) - - # NAN**y is NAN for any y except +-0 - self.assertTrue(isnan(pow_op(NAN, -INF))) - self.assertTrue(isnan(pow_op(NAN, -2.0))) - self.assertTrue(isnan(pow_op(NAN, -1.0))) - self.assertTrue(isnan(pow_op(NAN, -0.5))) - self.assertTrue(isnan(pow_op(NAN, 0.5))) - self.assertTrue(isnan(pow_op(NAN, 1.0))) - self.assertTrue(isnan(pow_op(NAN, 2.0))) - self.assertTrue(isnan(pow_op(NAN, INF))) - - # (+-0)**y raises ZeroDivisionError for y a negative odd integer - self.assertRaises(ZeroDivisionError, pow_op, -0.0, -1.0) - self.assertRaises(ZeroDivisionError, pow_op, 0.0, -1.0) - - # (+-0)**y raises ZeroDivisionError for y finite and 
negative - # but not an odd integer - self.assertRaises(ZeroDivisionError, pow_op, -0.0, -2.0) - self.assertRaises(ZeroDivisionError, pow_op, -0.0, -0.5) - self.assertRaises(ZeroDivisionError, pow_op, 0.0, -2.0) - self.assertRaises(ZeroDivisionError, pow_op, 0.0, -0.5) - - # (+-0)**y is +-0 for y a positive odd integer - #FIXME: Jython fails this. - #self.assertEqualAndEqualSign(pow_op(-0.0, 1.0), -0.0) - self.assertEqualAndEqualSign(pow_op(0.0, 1.0), 0.0) - - # (+-0)**y is 0 for y finite and positive but not an odd integer - self.assertEqualAndEqualSign(pow_op(-0.0, 0.5), 0.0) - self.assertEqualAndEqualSign(pow_op(-0.0, 2.0), 0.0) - self.assertEqualAndEqualSign(pow_op(0.0, 0.5), 0.0) - self.assertEqualAndEqualSign(pow_op(0.0, 2.0), 0.0) - - # (-1)**+-inf is 1 - #FIXME: Jython fails these. - #self.assertEqualAndEqualSign(pow_op(-1.0, -INF), 1.0) - #self.assertEqualAndEqualSign(pow_op(-1.0, INF), 1.0) - - # 1**y is 1 for any y, even if y is an infinity or nan - #FIXME: Jython fails some of this. - #self.assertEqualAndEqualSign(pow_op(1.0, -INF), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, -2.0), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, -1.0), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, -0.5), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, 0.5), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, 1.0), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, 2.0), 1.0) - #FIXME: Jython fails some of this. 
- #self.assertEqualAndEqualSign(pow_op(1.0, INF), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, NAN), 1.0) - - # x**+-0 is 1 for any x, even if x is a zero, infinity, or nan - self.assertEqualAndEqualSign(pow_op(-INF, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-2.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-0.5, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-0.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(0.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(0.5, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(2.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(INF, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(NAN, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-INF, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-2.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-0.5, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-0.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(0.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(0.5, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(2.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(INF, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(NAN, -0.0), 1.0) - - # x**y raises ValueError for finite negative x and non-integral y - self.assertRaises(ValueError, pow_op, -2.0, -0.5) - self.assertRaises(ValueError, pow_op, -2.0, 0.5) - self.assertRaises(ValueError, pow_op, -1.0, -0.5) - self.assertRaises(ValueError, pow_op, -1.0, 0.5) - self.assertRaises(ValueError, pow_op, -0.5, -0.5) - self.assertRaises(ValueError, pow_op, -0.5, 0.5) - - # x**-INF is INF for abs(x) < 1 - self.assertEqualAndEqualSign(pow_op(-0.5, -INF), INF) - #FIXME: Jython fails these. 
- #self.assertEqualAndEqualSign(pow_op(-0.0, -INF), INF) - #self.assertEqualAndEqualSign(pow_op(0.0, -INF), INF) - self.assertEqualAndEqualSign(pow_op(0.5, -INF), INF) - - # x**-INF is 0 for abs(x) > 1 - self.assertEqualAndEqualSign(pow_op(-INF, -INF), 0.0) - self.assertEqualAndEqualSign(pow_op(-2.0, -INF), 0.0) - self.assertEqualAndEqualSign(pow_op(2.0, -INF), 0.0) - self.assertEqualAndEqualSign(pow_op(INF, -INF), 0.0) - - # x**INF is 0 for abs(x) < 1 - self.assertEqualAndEqualSign(pow_op(-0.5, INF), 0.0) - self.assertEqualAndEqualSign(pow_op(-0.0, INF), 0.0) - self.assertEqualAndEqualSign(pow_op(0.0, INF), 0.0) - self.assertEqualAndEqualSign(pow_op(0.5, INF), 0.0) - - # x**INF is INF for abs(x) > 1 - self.assertEqualAndEqualSign(pow_op(-INF, INF), INF) - self.assertEqualAndEqualSign(pow_op(-2.0, INF), INF) - self.assertEqualAndEqualSign(pow_op(2.0, INF), INF) - self.assertEqualAndEqualSign(pow_op(INF, INF), INF) - - # (-INF)**y is -0.0 for y a negative odd integer - self.assertEqualAndEqualSign(pow_op(-INF, -1.0), -0.0) - - # (-INF)**y is 0.0 for y negative but not an odd integer - #FIXME: Jython fails this. - #self.assertEqualAndEqualSign(pow_op(-INF, -0.5), 0.0) - self.assertEqualAndEqualSign(pow_op(-INF, -2.0), 0.0) - - # (-INF)**y is -INF for y a positive odd integer - self.assertEqualAndEqualSign(pow_op(-INF, 1.0), -INF) - - # (-INF)**y is INF for y positive but not an odd integer - #FIXME: Jython fails this. 
- #self.assertEqualAndEqualSign(pow_op(-INF, 0.5), INF) - self.assertEqualAndEqualSign(pow_op(-INF, 2.0), INF) - - # INF**y is INF for y positive - self.assertEqualAndEqualSign(pow_op(INF, 0.5), INF) - self.assertEqualAndEqualSign(pow_op(INF, 1.0), INF) - self.assertEqualAndEqualSign(pow_op(INF, 2.0), INF) - - # INF**y is 0.0 for y negative - self.assertEqualAndEqualSign(pow_op(INF, -2.0), 0.0) - self.assertEqualAndEqualSign(pow_op(INF, -1.0), 0.0) - self.assertEqualAndEqualSign(pow_op(INF, -0.5), 0.0) - - # basic checks not covered by the special cases above - self.assertEqualAndEqualSign(pow_op(-2.0, -2.0), 0.25) - self.assertEqualAndEqualSign(pow_op(-2.0, -1.0), -0.5) - self.assertEqualAndEqualSign(pow_op(-2.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-2.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-2.0, 1.0), -2.0) - self.assertEqualAndEqualSign(pow_op(-2.0, 2.0), 4.0) - self.assertEqualAndEqualSign(pow_op(-1.0, -2.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, -1.0), -1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, 1.0), -1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, 2.0), 1.0) - self.assertEqualAndEqualSign(pow_op(2.0, -2.0), 0.25) - self.assertEqualAndEqualSign(pow_op(2.0, -1.0), 0.5) - self.assertEqualAndEqualSign(pow_op(2.0, -0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(2.0, 0.0), 1.0) - self.assertEqualAndEqualSign(pow_op(2.0, 1.0), 2.0) - self.assertEqualAndEqualSign(pow_op(2.0, 2.0), 4.0) - - # 1 ** large and -1 ** large; some libms apparently - # have problems with these - self.assertEqualAndEqualSign(pow_op(1.0, -1e100), 1.0) - self.assertEqualAndEqualSign(pow_op(1.0, 1e100), 1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, -1e100), 1.0) - self.assertEqualAndEqualSign(pow_op(-1.0, 1e100), 1.0) - - # check sign for results that underflow to 0 - self.assertEqualAndEqualSign(pow_op(-2.0, -2000.0), 0.0) - 
self.assertRaises(ValueError, pow_op, -2.0, -2000.5) - self.assertEqualAndEqualSign(pow_op(-2.0, -2001.0), -0.0) - self.assertEqualAndEqualSign(pow_op(2.0, -2000.0), 0.0) - self.assertEqualAndEqualSign(pow_op(2.0, -2000.5), 0.0) - self.assertEqualAndEqualSign(pow_op(2.0, -2001.0), 0.0) - self.assertEqualAndEqualSign(pow_op(-0.5, 2000.0), 0.0) - self.assertRaises(ValueError, pow_op, -0.5, 2000.5) - self.assertEqualAndEqualSign(pow_op(-0.5, 2001.0), -0.0) - self.assertEqualAndEqualSign(pow_op(0.5, 2000.0), 0.0) - self.assertEqualAndEqualSign(pow_op(0.5, 2000.5), 0.0) - self.assertEqualAndEqualSign(pow_op(0.5, 2001.0), 0.0) - - # check we don't raise an exception for subnormal results, - # and validate signs. Tests currently disabled, since - # they fail on systems where a subnormal result from pow - # is flushed to zero (e.g. Debian/ia64.) - #self.assertTrue(0.0 < pow_op(0.5, 1048) < 1e-315) - #self.assertTrue(0.0 < pow_op(-0.5, 1048) < 1e-315) - #self.assertTrue(0.0 < pow_op(0.5, 1047) < 1e-315) - #self.assertTrue(0.0 > pow_op(-0.5, 1047) > -1e-315) - #self.assertTrue(0.0 < pow_op(2.0, -1048) < 1e-315) - #self.assertTrue(0.0 < pow_op(-2.0, -1048) < 1e-315) - #self.assertTrue(0.0 < pow_op(2.0, -1047) < 1e-315) - #self.assertTrue(0.0 > pow_op(-2.0, -1047) > -1e-315) - - - at requires_setformat -class FormatFunctionsTestCase(unittest.TestCase): - - def setUp(self): - self.save_formats = {'double':float.__getformat__('double'), - 'float':float.__getformat__('float')} - - def tearDown(self): - float.__setformat__('double', self.save_formats['double']) - float.__setformat__('float', self.save_formats['float']) - - def test_getformat(self): - self.assertIn(float.__getformat__('double'), - ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) - self.assertIn(float.__getformat__('float'), - ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) - self.assertRaises(ValueError, float.__getformat__, 'chicken') - self.assertRaises(TypeError, float.__getformat__, 1) - - def 
test_setformat(self): - for t in 'double', 'float': - float.__setformat__(t, 'unknown') - if self.save_formats[t] == 'IEEE, big-endian': - self.assertRaises(ValueError, float.__setformat__, - t, 'IEEE, little-endian') - elif self.save_formats[t] == 'IEEE, little-endian': - self.assertRaises(ValueError, float.__setformat__, - t, 'IEEE, big-endian') - else: - self.assertRaises(ValueError, float.__setformat__, - t, 'IEEE, big-endian') - self.assertRaises(ValueError, float.__setformat__, - t, 'IEEE, little-endian') - self.assertRaises(ValueError, float.__setformat__, - t, 'chicken') - self.assertRaises(ValueError, float.__setformat__, - 'chicken', 'unknown') - -BE_DOUBLE_INF = '\x7f\xf0\x00\x00\x00\x00\x00\x00' -LE_DOUBLE_INF = ''.join(reversed(BE_DOUBLE_INF)) -BE_DOUBLE_NAN = '\x7f\xf8\x00\x00\x00\x00\x00\x00' -LE_DOUBLE_NAN = ''.join(reversed(BE_DOUBLE_NAN)) - -BE_FLOAT_INF = '\x7f\x80\x00\x00' -LE_FLOAT_INF = ''.join(reversed(BE_FLOAT_INF)) -BE_FLOAT_NAN = '\x7f\xc0\x00\x00' -LE_FLOAT_NAN = ''.join(reversed(BE_FLOAT_NAN)) - -# on non-IEEE platforms, attempting to unpack a bit pattern -# representing an infinity or a NaN should raise an exception. 
- - at requires_setformat -class UnknownFormatTestCase(unittest.TestCase): - def setUp(self): - self.save_formats = {'double':float.__getformat__('double'), - 'float':float.__getformat__('float')} - float.__setformat__('double', 'unknown') - float.__setformat__('float', 'unknown') - - def tearDown(self): - float.__setformat__('double', self.save_formats['double']) - float.__setformat__('float', self.save_formats['float']) - - def test_double_specials_dont_unpack(self): - for fmt, data in [('>d', BE_DOUBLE_INF), - ('>d', BE_DOUBLE_NAN), - ('f', BE_FLOAT_INF), - ('>f', BE_FLOAT_NAN), - ('d', BE_DOUBLE_INF), - ('>d', BE_DOUBLE_NAN), - ('f', BE_FLOAT_INF), - ('>f', BE_FLOAT_NAN), - (''), str(x)) - self.assertEqual(format(x, '2'), str(x)) - - self.assertEqual(format(1.0, 'f'), '1.000000') - - self.assertEqual(format(-1.0, 'f'), '-1.000000') - - self.assertEqual(format( 1.0, ' f'), ' 1.000000') - self.assertEqual(format(-1.0, ' f'), '-1.000000') - self.assertEqual(format( 1.0, '+f'), '+1.000000') - self.assertEqual(format(-1.0, '+f'), '-1.000000') - - # % formatting - self.assertEqual(format(-1.0, '%'), '-100.000000%') - - # conversion to string should fail - self.assertRaises(ValueError, format, 3.0, "s") - - # other format specifiers shouldn't work on floats, - # in particular int specifiers - for format_spec in ([chr(x) for x in range(ord('a'), ord('z')+1)] + - [chr(x) for x in range(ord('A'), ord('Z')+1)]): - if not format_spec in 'eEfFgGn%': - self.assertRaises(ValueError, format, 0.0, format_spec) - self.assertRaises(ValueError, format, 1.0, format_spec) - self.assertRaises(ValueError, format, -1.0, format_spec) - self.assertRaises(ValueError, format, 1e100, format_spec) - self.assertRaises(ValueError, format, -1e100, format_spec) - self.assertRaises(ValueError, format, 1e-100, format_spec) - self.assertRaises(ValueError, format, -1e-100, format_spec) - - # issue 3382: 'f' and 'F' with inf's and nan's - self.assertEqual('{0:f}'.format(INF), 'inf') - 
self.assertEqual('{0:F}'.format(INF), 'INF') - self.assertEqual('{0:f}'.format(-INF), '-inf') - self.assertEqual('{0:F}'.format(-INF), '-INF') - self.assertEqual('{0:f}'.format(NAN), 'nan') - self.assertEqual('{0:F}'.format(NAN), 'NAN') - - @requires_IEEE_754 - def test_format_testfile(self): - with open(format_testfile) as testfile: - for line in open(format_testfile): - if line.startswith('--'): - continue - line = line.strip() - if not line: - continue - - lhs, rhs = map(str.strip, line.split('->')) - fmt, arg = lhs.split() - arg = float(arg) - self.assertEqual(fmt % arg, rhs) - if not math.isnan(arg) and copysign(1.0, arg) > 0.0: - self.assertEqual(fmt % -arg, '-' + rhs) - - def test_issue5864(self): - self.assertEqual(format(123.456, '.4'), '123.5') - self.assertEqual(format(1234.56, '.4'), '1.235e+03') - self.assertEqual(format(12345.6, '.4'), '1.235e+04') - -class ReprTestCase(unittest.TestCase): - def test_repr(self): - floats_file = open(os.path.join(os.path.split(__file__)[0], - 'floating_points.txt')) - for line in floats_file: - line = line.strip() - if not line or line.startswith('#'): - continue - v = eval(line) - self.assertEqual(v, eval(repr(v))) - floats_file.close() - - @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short', - "applies only when using short float repr style") - def test_short_repr(self): - # test short float repr introduced in Python 3.1. One aspect - # of this repr is that we get some degree of str -> float -> - # str roundtripping. In particular, for any numeric string - # containing 15 or fewer significant digits, those exact same - # digits (modulo trailing zeros) should appear in the output. - # No more repr(0.03) -> "0.029999999999999999"! - - test_strings = [ - # output always includes *either* a decimal point and at - # least one digit after that point, or an exponent. - '0.0', - '1.0', - '0.01', - '0.02', - '0.03', - '0.04', - '0.05', - '1.23456789', - '10.0', - '100.0', - # values >= 1e16 get an exponent... 
- '1000000000000000.0', - '9999999999999990.0', - '1e+16', - '1e+17', - # ... and so do values < 1e-4 - '0.001', - '0.001001', - '0.00010000000000001', - '0.0001', - '9.999999999999e-05', - '1e-05', - # values designed to provoke failure if the FPU rounding - # precision isn't set correctly - '8.72293771110361e+25', - '7.47005307342313e+26', - '2.86438000439698e+28', - '8.89142905246179e+28', - '3.08578087079232e+35', - ] - - for s in test_strings: - negs = '-'+s - self.assertEqual(s, repr(float(s))) - self.assertEqual(negs, repr(float(negs))) - - - at requires_IEEE_754 -class RoundTestCase(unittest.TestCase): - def test_second_argument_type(self): - # any type with an __index__ method should be permitted as - # a second argument - self.assertAlmostEqual(round(12.34, True), 12.3) - - class MyIndex(object): - def __index__(self): return 4 - self.assertAlmostEqual(round(-0.123456, MyIndex()), -0.1235) - # but floats should be illegal - self.assertRaises(TypeError, round, 3.14159, 2.0) - - def test_inf_nan(self): - # rounding an infinity or nan returns the same number; - # (in py3k, rounding an infinity or nan raises an error, - # since the result can't be represented as a long). 
- self.assertEqual(round(INF), INF) - self.assertEqual(round(-INF), -INF) - self.assertTrue(math.isnan(round(NAN))) - for n in range(-5, 5): - self.assertEqual(round(INF, n), INF) - self.assertEqual(round(-INF, n), -INF) - self.assertTrue(math.isnan(round(NAN, n))) - - self.assertRaises(TypeError, round, INF, 0.0) - self.assertRaises(TypeError, round, -INF, 1.0) - self.assertRaises(TypeError, round, NAN, "ceci n'est pas un entier") - self.assertRaises(TypeError, round, -0.0, 1j) - - def test_large_n(self): - for n in [324, 325, 400, 2**31-1, 2**31, 2**32, 2**100]: - self.assertEqual(round(123.456, n), 123.456) - self.assertEqual(round(-123.456, n), -123.456) - self.assertEqual(round(1e300, n), 1e300) - self.assertEqual(round(1e-320, n), 1e-320) - self.assertEqual(round(1e150, 300), 1e150) - self.assertEqual(round(1e300, 307), 1e300) - self.assertEqual(round(-3.1415, 308), -3.1415) - self.assertEqual(round(1e150, 309), 1e150) - self.assertEqual(round(1.4e-315, 315), 1e-315) - - def test_small_n(self): - for n in [-308, -309, -400, 1-2**31, -2**31, -2**31-1, -2**100]: - self.assertEqual(round(123.456, n), 0.0) - self.assertEqual(round(-123.456, n), -0.0) - self.assertEqual(round(1e300, n), 0.0) - self.assertEqual(round(1e-320, n), 0.0) - - def test_overflow(self): - self.assertRaises(OverflowError, round, 1.6e308, -308) - self.assertRaises(OverflowError, round, -1.7e308, -308) - - @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short', - "test applies only when using short float repr style") - def test_previous_round_bugs(self): - # particular cases that have occurred in bug reports - self.assertEqual(round(562949953421312.5, 1), - 562949953421312.5) - self.assertEqual(round(56294995342131.5, 3), - 56294995342131.5) - - @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short', - "test applies only when using short float repr style") - def test_halfway_cases(self): - # Halfway cases need special attention, since the current - # implementation 
has to deal with them specially. Note that - # 2.x rounds halfway values up (i.e., away from zero) while - # 3.x does round-half-to-even. - self.assertAlmostEqual(round(0.125, 2), 0.13) - self.assertAlmostEqual(round(0.375, 2), 0.38) - self.assertAlmostEqual(round(0.625, 2), 0.63) - self.assertAlmostEqual(round(0.875, 2), 0.88) - self.assertAlmostEqual(round(-0.125, 2), -0.13) - self.assertAlmostEqual(round(-0.375, 2), -0.38) - self.assertAlmostEqual(round(-0.625, 2), -0.63) - self.assertAlmostEqual(round(-0.875, 2), -0.88) - - self.assertAlmostEqual(round(0.25, 1), 0.3) - self.assertAlmostEqual(round(0.75, 1), 0.8) - self.assertAlmostEqual(round(-0.25, 1), -0.3) - self.assertAlmostEqual(round(-0.75, 1), -0.8) - - self.assertEqual(round(-6.5, 0), -7.0) - self.assertEqual(round(-5.5, 0), -6.0) - self.assertEqual(round(-1.5, 0), -2.0) - self.assertEqual(round(-0.5, 0), -1.0) - self.assertEqual(round(0.5, 0), 1.0) - self.assertEqual(round(1.5, 0), 2.0) - self.assertEqual(round(2.5, 0), 3.0) - self.assertEqual(round(3.5, 0), 4.0) - self.assertEqual(round(4.5, 0), 5.0) - self.assertEqual(round(5.5, 0), 6.0) - self.assertEqual(round(6.5, 0), 7.0) - - # same but without an explicit second argument; in 3.x these - # will give integers - self.assertEqual(round(-6.5), -7.0) - self.assertEqual(round(-5.5), -6.0) - self.assertEqual(round(-1.5), -2.0) - self.assertEqual(round(-0.5), -1.0) - self.assertEqual(round(0.5), 1.0) - self.assertEqual(round(1.5), 2.0) - self.assertEqual(round(2.5), 3.0) - self.assertEqual(round(3.5), 4.0) - self.assertEqual(round(4.5), 5.0) - self.assertEqual(round(5.5), 6.0) - self.assertEqual(round(6.5), 7.0) - - self.assertEqual(round(-25.0, -1), -30.0) - self.assertEqual(round(-15.0, -1), -20.0) - self.assertEqual(round(-5.0, -1), -10.0) - self.assertEqual(round(5.0, -1), 10.0) - self.assertEqual(round(15.0, -1), 20.0) - self.assertEqual(round(25.0, -1), 30.0) - self.assertEqual(round(35.0, -1), 40.0) - self.assertEqual(round(45.0, -1), 50.0) - 
self.assertEqual(round(55.0, -1), 60.0) - self.assertEqual(round(65.0, -1), 70.0) - self.assertEqual(round(75.0, -1), 80.0) - self.assertEqual(round(85.0, -1), 90.0) - self.assertEqual(round(95.0, -1), 100.0) - self.assertEqual(round(12325.0, -1), 12330.0) - - self.assertEqual(round(350.0, -2), 400.0) - self.assertEqual(round(450.0, -2), 500.0) - - self.assertAlmostEqual(round(0.5e21, -21), 1e21) - self.assertAlmostEqual(round(1.5e21, -21), 2e21) - self.assertAlmostEqual(round(2.5e21, -21), 3e21) - self.assertAlmostEqual(round(5.5e21, -21), 6e21) - self.assertAlmostEqual(round(8.5e21, -21), 9e21) - - self.assertAlmostEqual(round(-1.5e22, -22), -2e22) - self.assertAlmostEqual(round(-0.5e22, -22), -1e22) - self.assertAlmostEqual(round(0.5e22, -22), 1e22) - self.assertAlmostEqual(round(1.5e22, -22), 2e22) - - - @requires_IEEE_754 - def test_format_specials(self): - # Test formatting of nans and infs. - - def test(fmt, value, expected): - # Test with both % and format(). - self.assertEqual(fmt % value, expected, fmt) - if not '#' in fmt: - # Until issue 7094 is implemented, format() for floats doesn't - # support '#' formatting - fmt = fmt[1:] # strip off the % - self.assertEqual(format(value, fmt), expected, fmt) - - for fmt in ['%e', '%f', '%g', '%.0e', '%.6f', '%.20g', - '%#e', '%#f', '%#g', '%#.20e', '%#.15f', '%#.3g']: - pfmt = '%+' + fmt[1:] - sfmt = '% ' + fmt[1:] - test(fmt, INF, 'inf') - test(fmt, -INF, '-inf') - test(fmt, NAN, 'nan') - test(fmt, -NAN, 'nan') - # When asking for a sign, it's always provided. nans are - # always positive. - test(pfmt, INF, '+inf') - test(pfmt, -INF, '-inf') - test(pfmt, NAN, '+nan') - test(pfmt, -NAN, '+nan') - # When using ' ' for a sign code, only infs can be negative. - # Others have a space. 
- test(sfmt, INF, ' inf') - test(sfmt, -INF, '-inf') - test(sfmt, NAN, ' nan') - test(sfmt, -NAN, ' nan') - - -# Beginning with Python 2.6 float has cross platform compatible -# ways to create and represent inf and nan -class InfNanTest(unittest.TestCase): - def test_inf_from_str(self): - self.assertTrue(isinf(float("inf"))) - self.assertTrue(isinf(float("+inf"))) - self.assertTrue(isinf(float("-inf"))) - self.assertTrue(isinf(float("infinity"))) - self.assertTrue(isinf(float("+infinity"))) - self.assertTrue(isinf(float("-infinity"))) - - self.assertEqual(repr(float("inf")), "inf") - self.assertEqual(repr(float("+inf")), "inf") - self.assertEqual(repr(float("-inf")), "-inf") - self.assertEqual(repr(float("infinity")), "inf") - self.assertEqual(repr(float("+infinity")), "inf") - self.assertEqual(repr(float("-infinity")), "-inf") - - self.assertEqual(repr(float("INF")), "inf") - self.assertEqual(repr(float("+Inf")), "inf") - self.assertEqual(repr(float("-iNF")), "-inf") - self.assertEqual(repr(float("Infinity")), "inf") - self.assertEqual(repr(float("+iNfInItY")), "inf") - self.assertEqual(repr(float("-INFINITY")), "-inf") - - self.assertEqual(str(float("inf")), "inf") - self.assertEqual(str(float("+inf")), "inf") - self.assertEqual(str(float("-inf")), "-inf") - self.assertEqual(str(float("infinity")), "inf") - self.assertEqual(str(float("+infinity")), "inf") - self.assertEqual(str(float("-infinity")), "-inf") - - self.assertRaises(ValueError, float, "info") - self.assertRaises(ValueError, float, "+info") - self.assertRaises(ValueError, float, "-info") - self.assertRaises(ValueError, float, "in") - self.assertRaises(ValueError, float, "+in") - self.assertRaises(ValueError, float, "-in") - self.assertRaises(ValueError, float, "infinit") - self.assertRaises(ValueError, float, "+Infin") - self.assertRaises(ValueError, float, "-INFI") - self.assertRaises(ValueError, float, "infinitys") - - def test_inf_as_str(self): - self.assertEqual(repr(1e300 * 1e300), "inf") - 
self.assertEqual(repr(-1e300 * 1e300), "-inf") - - self.assertEqual(str(1e300 * 1e300), "inf") - self.assertEqual(str(-1e300 * 1e300), "-inf") - - def test_nan_from_str(self): - self.assertTrue(isnan(float("nan"))) - self.assertTrue(isnan(float("+nan"))) - self.assertTrue(isnan(float("-nan"))) - - self.assertEqual(repr(float("nan")), "nan") - self.assertEqual(repr(float("+nan")), "nan") - self.assertEqual(repr(float("-nan")), "nan") - - self.assertEqual(repr(float("NAN")), "nan") - self.assertEqual(repr(float("+NAn")), "nan") - self.assertEqual(repr(float("-NaN")), "nan") - - self.assertEqual(str(float("nan")), "nan") - self.assertEqual(str(float("+nan")), "nan") - self.assertEqual(str(float("-nan")), "nan") - - self.assertRaises(ValueError, float, "nana") - self.assertRaises(ValueError, float, "+nana") - self.assertRaises(ValueError, float, "-nana") - self.assertRaises(ValueError, float, "na") - self.assertRaises(ValueError, float, "+na") - self.assertRaises(ValueError, float, "-na") - - def test_nan_as_str(self): - self.assertEqual(repr(1e300 * 1e300 * 0), "nan") - self.assertEqual(repr(-1e300 * 1e300 * 0), "nan") - - self.assertEqual(str(1e300 * 1e300 * 0), "nan") - self.assertEqual(str(-1e300 * 1e300 * 0), "nan") - - def notest_float_nan(self): - self.assertTrue(NAN.is_nan()) - self.assertFalse(INF.is_nan()) - self.assertFalse((0.).is_nan()) - - def notest_float_inf(self): - self.assertTrue(INF.is_inf()) - self.assertFalse(NAN.is_inf()) - self.assertFalse((0.).is_inf()) - - def test_hash_inf(self): - # the actual values here should be regarded as an - # implementation detail, but they need to be - # identical to those used in the Decimal module. 
- self.assertEqual(hash(float('inf')), 314159) - self.assertEqual(hash(float('-inf')), -271828) - self.assertEqual(hash(float('nan')), 0) - - -fromHex = float.fromhex -toHex = float.hex -class HexFloatTestCase(unittest.TestCase): - MAX = fromHex('0x.fffffffffffff8p+1024') # max normal - MIN = fromHex('0x1p-1022') # min normal - TINY = fromHex('0x0.0000000000001p-1022') # min subnormal - EPS = fromHex('0x0.0000000000001p0') # diff between 1.0 and next float up - - def identical(self, x, y): - # check that floats x and y are identical, or that both - # are NaNs - if isnan(x) or isnan(y): - if isnan(x) == isnan(y): - return - elif x == y and (x != 0.0 or copysign(1.0, x) == copysign(1.0, y)): - return - self.fail('%r not identical to %r' % (x, y)) - - def test_ends(self): - self.identical(self.MIN, ldexp(1.0, -1022)) - self.identical(self.TINY, ldexp(1.0, -1074)) - self.identical(self.EPS, ldexp(1.0, -52)) - self.identical(self.MAX, 2.*(ldexp(1.0, 1023) - ldexp(1.0, 970))) - - def test_invalid_inputs(self): - invalid_inputs = [ - 'infi', # misspelt infinities and nans - '-Infinit', - '++inf', - '-+Inf', - '--nan', - '+-NaN', - 'snan', - 'NaNs', - 'nna', - 'an', - 'nf', - 'nfinity', - 'inity', - 'iinity', - '0xnan', - '', - ' ', - 'x1.0p0', - '0xX1.0p0', - '+ 0x1.0p0', # internal whitespace - '- 0x1.0p0', - '0 x1.0p0', - '0x 1.0p0', - '0x1 2.0p0', - '+0x1 .0p0', - '0x1. 
0p0', - '-0x1.0 1p0', - '-0x1.0 p0', - '+0x1.0p +0', - '0x1.0p -0', - '0x1.0p 0', - '+0x1.0p+ 0', - '-0x1.0p- 0', - '++0x1.0p-0', # double signs - '--0x1.0p0', - '+-0x1.0p+0', - '-+0x1.0p0', - '0x1.0p++0', - '+0x1.0p+-0', - '-0x1.0p-+0', - '0x1.0p--0', - '0x1.0.p0', - '0x.p0', # no hex digits before or after point - '0x1,p0', # wrong decimal point character - '0x1pa', - u'0x1p\uff10', # fullwidth Unicode digits - u'\uff10x1p0', - u'0x\uff11p0', - u'0x1.\uff10p0', - '0x1p0 \n 0x2p0', - '0x1p0\0 0x1p0', # embedded null byte is not end of string - ] - for x in invalid_inputs: - try: - result = fromHex(x) - except ValueError: - pass - else: - self.fail('Expected float.fromhex(%r) to raise ValueError; ' - 'got %r instead' % (x, result)) - - - def test_whitespace(self): - value_pairs = [ - ('inf', INF), - ('-Infinity', -INF), - ('nan', NAN), - ('1.0', 1.0), - ('-0x.2', -0.125), - ('-0.0', -0.0) - ] - whitespace = [ - '', - ' ', - '\t', - '\n', - '\n \t', - '\f', - '\v', - '\r' - ] - for inp, expected in value_pairs: - for lead in whitespace: - for trail in whitespace: - got = fromHex(lead + inp + trail) - self.identical(got, expected) - - - def test_from_hex(self): - MIN = self.MIN; - MAX = self.MAX; - TINY = self.TINY; - EPS = self.EPS; - - # two spellings of infinity, with optional signs; case-insensitive - self.identical(fromHex('inf'), INF) - self.identical(fromHex('+Inf'), INF) - self.identical(fromHex('-INF'), -INF) - self.identical(fromHex('iNf'), INF) - self.identical(fromHex('Infinity'), INF) - self.identical(fromHex('+INFINITY'), INF) - self.identical(fromHex('-infinity'), -INF) - self.identical(fromHex('-iNFiNitY'), -INF) - - # nans with optional sign; case insensitive - self.identical(fromHex('nan'), NAN) - self.identical(fromHex('+NaN'), NAN) - self.identical(fromHex('-NaN'), NAN) - self.identical(fromHex('-nAN'), NAN) - - # variations in input format - self.identical(fromHex('1'), 1.0) - self.identical(fromHex('+1'), 1.0) - self.identical(fromHex('1.'), 
1.0) - self.identical(fromHex('1.0'), 1.0) - self.identical(fromHex('1.0p0'), 1.0) - self.identical(fromHex('01'), 1.0) - self.identical(fromHex('01.'), 1.0) - self.identical(fromHex('0x1'), 1.0) - self.identical(fromHex('0x1.'), 1.0) - self.identical(fromHex('0x1.0'), 1.0) - self.identical(fromHex('+0x1.0'), 1.0) - self.identical(fromHex('0x1p0'), 1.0) - self.identical(fromHex('0X1p0'), 1.0) - self.identical(fromHex('0X1P0'), 1.0) - self.identical(fromHex('0x1P0'), 1.0) - self.identical(fromHex('0x1.p0'), 1.0) - self.identical(fromHex('0x1.0p0'), 1.0) - self.identical(fromHex('0x.1p4'), 1.0) - self.identical(fromHex('0x.1p04'), 1.0) - self.identical(fromHex('0x.1p004'), 1.0) - self.identical(fromHex('0x1p+0'), 1.0) - self.identical(fromHex('0x1P-0'), 1.0) - self.identical(fromHex('+0x1p0'), 1.0) - self.identical(fromHex('0x01p0'), 1.0) - self.identical(fromHex('0x1p00'), 1.0) - self.identical(fromHex(u'0x1p0'), 1.0) - self.identical(fromHex(' 0x1p0 '), 1.0) - self.identical(fromHex('\n 0x1p0'), 1.0) - self.identical(fromHex('0x1p0 \t'), 1.0) - self.identical(fromHex('0xap0'), 10.0) - self.identical(fromHex('0xAp0'), 10.0) - self.identical(fromHex('0xaP0'), 10.0) - self.identical(fromHex('0xAP0'), 10.0) - self.identical(fromHex('0xbep0'), 190.0) - self.identical(fromHex('0xBep0'), 190.0) - self.identical(fromHex('0xbEp0'), 190.0) - self.identical(fromHex('0XBE0P-4'), 190.0) - self.identical(fromHex('0xBEp0'), 190.0) - self.identical(fromHex('0xB.Ep4'), 190.0) - self.identical(fromHex('0x.BEp8'), 190.0) - self.identical(fromHex('0x.0BEp12'), 190.0) - - # moving the point around - pi = fromHex('0x1.921fb54442d18p1') - self.identical(fromHex('0x.006487ed5110b46p11'), pi) - self.identical(fromHex('0x.00c90fdaa22168cp10'), pi) - self.identical(fromHex('0x.01921fb54442d18p9'), pi) - self.identical(fromHex('0x.03243f6a8885a3p8'), pi) - self.identical(fromHex('0x.06487ed5110b46p7'), pi) - self.identical(fromHex('0x.0c90fdaa22168cp6'), pi) - 
self.identical(fromHex('0x.1921fb54442d18p5'), pi) - self.identical(fromHex('0x.3243f6a8885a3p4'), pi) - self.identical(fromHex('0x.6487ed5110b46p3'), pi) - self.identical(fromHex('0x.c90fdaa22168cp2'), pi) - self.identical(fromHex('0x1.921fb54442d18p1'), pi) - self.identical(fromHex('0x3.243f6a8885a3p0'), pi) - self.identical(fromHex('0x6.487ed5110b46p-1'), pi) - self.identical(fromHex('0xc.90fdaa22168cp-2'), pi) - self.identical(fromHex('0x19.21fb54442d18p-3'), pi) - self.identical(fromHex('0x32.43f6a8885a3p-4'), pi) - self.identical(fromHex('0x64.87ed5110b46p-5'), pi) - self.identical(fromHex('0xc9.0fdaa22168cp-6'), pi) - self.identical(fromHex('0x192.1fb54442d18p-7'), pi) - self.identical(fromHex('0x324.3f6a8885a3p-8'), pi) - self.identical(fromHex('0x648.7ed5110b46p-9'), pi) - self.identical(fromHex('0xc90.fdaa22168cp-10'), pi) - self.identical(fromHex('0x1921.fb54442d18p-11'), pi) - # ... - self.identical(fromHex('0x1921fb54442d1.8p-47'), pi) - self.identical(fromHex('0x3243f6a8885a3p-48'), pi) - self.identical(fromHex('0x6487ed5110b46p-49'), pi) - self.identical(fromHex('0xc90fdaa22168cp-50'), pi) - self.identical(fromHex('0x1921fb54442d18p-51'), pi) - self.identical(fromHex('0x3243f6a8885a30p-52'), pi) - self.identical(fromHex('0x6487ed5110b460p-53'), pi) - self.identical(fromHex('0xc90fdaa22168c0p-54'), pi) - self.identical(fromHex('0x1921fb54442d180p-55'), pi) - - - # results that should overflow... 
- self.assertRaises(OverflowError, fromHex, '-0x1p1024') - self.assertRaises(OverflowError, fromHex, '0x1p+1025') - self.assertRaises(OverflowError, fromHex, '+0X1p1030') - self.assertRaises(OverflowError, fromHex, '-0x1p+1100') - self.assertRaises(OverflowError, fromHex, '0X1p123456789123456789') - self.assertRaises(OverflowError, fromHex, '+0X.8p+1025') - self.assertRaises(OverflowError, fromHex, '+0x0.8p1025') - self.assertRaises(OverflowError, fromHex, '-0x0.4p1026') - self.assertRaises(OverflowError, fromHex, '0X2p+1023') - self.assertRaises(OverflowError, fromHex, '0x2.p1023') - self.assertRaises(OverflowError, fromHex, '-0x2.0p+1023') - self.assertRaises(OverflowError, fromHex, '+0X4p+1022') - self.assertRaises(OverflowError, fromHex, '0x1.ffffffffffffffp+1023') - self.assertRaises(OverflowError, fromHex, '-0X1.fffffffffffff9p1023') - self.assertRaises(OverflowError, fromHex, '0X1.fffffffffffff8p1023') - self.assertRaises(OverflowError, fromHex, '+0x3.fffffffffffffp1022') - self.assertRaises(OverflowError, fromHex, '0x3fffffffffffffp+970') - self.assertRaises(OverflowError, fromHex, '0x10000000000000000p960') - self.assertRaises(OverflowError, fromHex, '-0Xffffffffffffffffp960') - - # ...and those that round to +-max float - self.identical(fromHex('+0x1.fffffffffffffp+1023'), MAX) - self.identical(fromHex('-0X1.fffffffffffff7p1023'), -MAX) - self.identical(fromHex('0X1.fffffffffffff7fffffffffffffp1023'), MAX) - - # zeros - self.identical(fromHex('0x0p0'), 0.0) - self.identical(fromHex('0x0p1000'), 0.0) - self.identical(fromHex('-0x0p1023'), -0.0) - self.identical(fromHex('0X0p1024'), 0.0) - self.identical(fromHex('-0x0p1025'), -0.0) - self.identical(fromHex('0X0p2000'), 0.0) - self.identical(fromHex('0x0p123456789123456789'), 0.0) - self.identical(fromHex('-0X0p-0'), -0.0) - self.identical(fromHex('-0X0p-1000'), -0.0) - self.identical(fromHex('0x0p-1023'), 0.0) - self.identical(fromHex('-0X0p-1024'), -0.0) - self.identical(fromHex('-0x0p-1025'), -0.0) - 
self.identical(fromHex('-0x0p-1072'), -0.0) - self.identical(fromHex('0X0p-1073'), 0.0) - self.identical(fromHex('-0x0p-1074'), -0.0) - self.identical(fromHex('0x0p-1075'), 0.0) - self.identical(fromHex('0X0p-1076'), 0.0) - self.identical(fromHex('-0X0p-2000'), -0.0) - self.identical(fromHex('-0x0p-123456789123456789'), -0.0) - - # values that should underflow to 0 - self.identical(fromHex('0X1p-1075'), 0.0) - self.identical(fromHex('-0X1p-1075'), -0.0) - self.identical(fromHex('-0x1p-123456789123456789'), -0.0) - self.identical(fromHex('0x1.00000000000000001p-1075'), TINY) - self.identical(fromHex('-0x1.1p-1075'), -TINY) - self.identical(fromHex('0x1.fffffffffffffffffp-1075'), TINY) - - # check round-half-even is working correctly near 0 ... - self.identical(fromHex('0x1p-1076'), 0.0) - self.identical(fromHex('0X2p-1076'), 0.0) - self.identical(fromHex('0X3p-1076'), TINY) - self.identical(fromHex('0x4p-1076'), TINY) - self.identical(fromHex('0X5p-1076'), TINY) - self.identical(fromHex('0X6p-1076'), 2*TINY) - self.identical(fromHex('0x7p-1076'), 2*TINY) - self.identical(fromHex('0X8p-1076'), 2*TINY) - self.identical(fromHex('0X9p-1076'), 2*TINY) - self.identical(fromHex('0xap-1076'), 2*TINY) - self.identical(fromHex('0Xbp-1076'), 3*TINY) - self.identical(fromHex('0xcp-1076'), 3*TINY) - self.identical(fromHex('0Xdp-1076'), 3*TINY) - self.identical(fromHex('0Xep-1076'), 4*TINY) - self.identical(fromHex('0xfp-1076'), 4*TINY) - self.identical(fromHex('0x10p-1076'), 4*TINY) - self.identical(fromHex('-0x1p-1076'), -0.0) - self.identical(fromHex('-0X2p-1076'), -0.0) - self.identical(fromHex('-0x3p-1076'), -TINY) - self.identical(fromHex('-0X4p-1076'), -TINY) - self.identical(fromHex('-0x5p-1076'), -TINY) - self.identical(fromHex('-0x6p-1076'), -2*TINY) - self.identical(fromHex('-0X7p-1076'), -2*TINY) - self.identical(fromHex('-0X8p-1076'), -2*TINY) - self.identical(fromHex('-0X9p-1076'), -2*TINY) - self.identical(fromHex('-0Xap-1076'), -2*TINY) - 
self.identical(fromHex('-0xbp-1076'), -3*TINY) - self.identical(fromHex('-0xcp-1076'), -3*TINY) - self.identical(fromHex('-0Xdp-1076'), -3*TINY) - self.identical(fromHex('-0xep-1076'), -4*TINY) - self.identical(fromHex('-0Xfp-1076'), -4*TINY) - self.identical(fromHex('-0X10p-1076'), -4*TINY) - - # ... and near MIN ... - self.identical(fromHex('0x0.ffffffffffffd6p-1022'), MIN-3*TINY) - self.identical(fromHex('0x0.ffffffffffffd8p-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffdap-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffdcp-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffdep-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffe0p-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffe2p-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffe4p-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffe6p-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffe8p-1022'), MIN-2*TINY) - self.identical(fromHex('0x0.ffffffffffffeap-1022'), MIN-TINY) - self.identical(fromHex('0x0.ffffffffffffecp-1022'), MIN-TINY) - self.identical(fromHex('0x0.ffffffffffffeep-1022'), MIN-TINY) - self.identical(fromHex('0x0.fffffffffffff0p-1022'), MIN-TINY) - self.identical(fromHex('0x0.fffffffffffff2p-1022'), MIN-TINY) - self.identical(fromHex('0x0.fffffffffffff4p-1022'), MIN-TINY) - self.identical(fromHex('0x0.fffffffffffff6p-1022'), MIN-TINY) - self.identical(fromHex('0x0.fffffffffffff8p-1022'), MIN) - self.identical(fromHex('0x0.fffffffffffffap-1022'), MIN) - self.identical(fromHex('0x0.fffffffffffffcp-1022'), MIN) - self.identical(fromHex('0x0.fffffffffffffep-1022'), MIN) - self.identical(fromHex('0x1.00000000000000p-1022'), MIN) - self.identical(fromHex('0x1.00000000000002p-1022'), MIN) - self.identical(fromHex('0x1.00000000000004p-1022'), MIN) - self.identical(fromHex('0x1.00000000000006p-1022'), MIN) - self.identical(fromHex('0x1.00000000000008p-1022'), MIN) - 
self.identical(fromHex('0x1.0000000000000ap-1022'), MIN+TINY) - self.identical(fromHex('0x1.0000000000000cp-1022'), MIN+TINY) - self.identical(fromHex('0x1.0000000000000ep-1022'), MIN+TINY) - self.identical(fromHex('0x1.00000000000010p-1022'), MIN+TINY) - self.identical(fromHex('0x1.00000000000012p-1022'), MIN+TINY) - self.identical(fromHex('0x1.00000000000014p-1022'), MIN+TINY) - self.identical(fromHex('0x1.00000000000016p-1022'), MIN+TINY) - self.identical(fromHex('0x1.00000000000018p-1022'), MIN+2*TINY) - - # ... and near 1.0. - self.identical(fromHex('0x0.fffffffffffff0p0'), 1.0-EPS) - self.identical(fromHex('0x0.fffffffffffff1p0'), 1.0-EPS) - self.identical(fromHex('0X0.fffffffffffff2p0'), 1.0-EPS) - self.identical(fromHex('0x0.fffffffffffff3p0'), 1.0-EPS) - self.identical(fromHex('0X0.fffffffffffff4p0'), 1.0-EPS) - self.identical(fromHex('0X0.fffffffffffff5p0'), 1.0-EPS/2) - self.identical(fromHex('0X0.fffffffffffff6p0'), 1.0-EPS/2) - self.identical(fromHex('0x0.fffffffffffff7p0'), 1.0-EPS/2) - self.identical(fromHex('0x0.fffffffffffff8p0'), 1.0-EPS/2) - self.identical(fromHex('0X0.fffffffffffff9p0'), 1.0-EPS/2) - self.identical(fromHex('0X0.fffffffffffffap0'), 1.0-EPS/2) - self.identical(fromHex('0x0.fffffffffffffbp0'), 1.0-EPS/2) - self.identical(fromHex('0X0.fffffffffffffcp0'), 1.0) - self.identical(fromHex('0x0.fffffffffffffdp0'), 1.0) - self.identical(fromHex('0X0.fffffffffffffep0'), 1.0) - self.identical(fromHex('0x0.ffffffffffffffp0'), 1.0) - self.identical(fromHex('0X1.00000000000000p0'), 1.0) - self.identical(fromHex('0X1.00000000000001p0'), 1.0) - self.identical(fromHex('0x1.00000000000002p0'), 1.0) - self.identical(fromHex('0X1.00000000000003p0'), 1.0) - self.identical(fromHex('0x1.00000000000004p0'), 1.0) - self.identical(fromHex('0X1.00000000000005p0'), 1.0) - self.identical(fromHex('0X1.00000000000006p0'), 1.0) - self.identical(fromHex('0X1.00000000000007p0'), 1.0) - self.identical(fromHex('0x1.00000000000007ffffffffffffffffffffp0'), - 1.0) - 
self.identical(fromHex('0x1.00000000000008p0'), 1.0) - self.identical(fromHex('0x1.00000000000008000000000000000001p0'), - 1+EPS) - self.identical(fromHex('0X1.00000000000009p0'), 1.0+EPS) - self.identical(fromHex('0x1.0000000000000ap0'), 1.0+EPS) - self.identical(fromHex('0x1.0000000000000bp0'), 1.0+EPS) - self.identical(fromHex('0X1.0000000000000cp0'), 1.0+EPS) - self.identical(fromHex('0x1.0000000000000dp0'), 1.0+EPS) - self.identical(fromHex('0x1.0000000000000ep0'), 1.0+EPS) - self.identical(fromHex('0X1.0000000000000fp0'), 1.0+EPS) - self.identical(fromHex('0x1.00000000000010p0'), 1.0+EPS) - self.identical(fromHex('0X1.00000000000011p0'), 1.0+EPS) - self.identical(fromHex('0x1.00000000000012p0'), 1.0+EPS) - self.identical(fromHex('0X1.00000000000013p0'), 1.0+EPS) - self.identical(fromHex('0X1.00000000000014p0'), 1.0+EPS) - self.identical(fromHex('0x1.00000000000015p0'), 1.0+EPS) - self.identical(fromHex('0x1.00000000000016p0'), 1.0+EPS) - self.identical(fromHex('0X1.00000000000017p0'), 1.0+EPS) - self.identical(fromHex('0x1.00000000000017ffffffffffffffffffffp0'), - 1.0+EPS) - self.identical(fromHex('0x1.00000000000018p0'), 1.0+2*EPS) - self.identical(fromHex('0X1.00000000000018000000000000000001p0'), - 1.0+2*EPS) - self.identical(fromHex('0x1.00000000000019p0'), 1.0+2*EPS) - self.identical(fromHex('0X1.0000000000001ap0'), 1.0+2*EPS) - self.identical(fromHex('0X1.0000000000001bp0'), 1.0+2*EPS) - self.identical(fromHex('0x1.0000000000001cp0'), 1.0+2*EPS) - self.identical(fromHex('0x1.0000000000001dp0'), 1.0+2*EPS) - self.identical(fromHex('0x1.0000000000001ep0'), 1.0+2*EPS) - self.identical(fromHex('0X1.0000000000001fp0'), 1.0+2*EPS) - self.identical(fromHex('0x1.00000000000020p0'), 1.0+2*EPS) - - def test_roundtrip(self): - def roundtrip(x): - return fromHex(toHex(x)) - - for x in [NAN, INF, self.MAX, self.MIN, self.MIN-self.TINY, self.TINY, 0.0]: - self.identical(x, roundtrip(x)) - self.identical(-x, roundtrip(-x)) - - # fromHex(toHex(x)) should exactly 
recover x, for any non-NaN float x. - import random - for i in xrange(10000): - e = random.randrange(-1200, 1200) - m = random.random() - s = random.choice([1.0, -1.0]) - try: - x = s*ldexp(m, e) - except OverflowError: - pass - else: - self.identical(x, fromHex(toHex(x))) - - -def test_main(): - test_support.run_unittest( - GeneralFloatCases, - FormatFunctionsTestCase, - UnknownFormatTestCase, - IEEEFormatTestCase, - ReprTestCase, - RoundTestCase, - InfNanTest, - HexFloatTestCase, - ) - -if __name__ == '__main__': - test_main() diff --git a/src/org/python/core/PyFloat.java b/src/org/python/core/PyFloat.java --- a/src/org/python/core/PyFloat.java +++ b/src/org/python/core/PyFloat.java @@ -731,9 +731,9 @@ return null; } else if (modulo != null) { throw Py.TypeError("pow() 3rd argument not allowed unless all arguments are integers"); + } else { + return _pow(getValue(), coerce(right)); } - - return _pow(getValue(), coerce(right), modulo); } @ExposedMethod(type = MethodType.BINARY, doc = BuiltinDocs.float___rpow___doc) @@ -745,43 +745,60 @@ public PyObject __rpow__(PyObject left) { if (!canCoerce(left)) { return null; + } else { + return _pow(coerce(left), getValue()); } - - return _pow(coerce(left), getValue(), null); } - private static PyFloat _pow(double value, double iw, PyObject modulo) { - // Rely completely on Java's pow function - if (iw == 0) { - if (modulo != null) { - return new PyFloat(modulo(1.0, coerce(modulo))); - } + private static PyFloat _pow(double v, double w) { + /* + * This code was translated from the CPython implementation at v2.7.8 by progressively + * removing cases that could be delegated to Java. Jython differs from CPython in that where + * C pow() overflows, Java pow() returns inf (observed on Windows). This is not subject to + * regression tests, so we take it as an allowable platform dependency. All other + * differences in Java Math.pow() are trapped below and Python behaviour is enforced. 
+ */ + if (w == 0) { + // v**0 is 1, even 0**0 return new PyFloat(1.0); - } - if (value == 0.0) { - if (iw < 0.0) { - throw Py.ZeroDivisionError("0.0 cannot be raised to a negative power"); - } else if (Double.isNaN(iw)) { - return new PyFloat(Double.NaN); - } - return new PyFloat(0); - } + } else if (Double.isNaN(v)) { + // nan**w = nan, unless w == 0 + return new PyFloat(Double.NaN); - if (Double.isNaN(iw)) { - if (value == 1.0) { + } else if (Double.isNaN(w)) { + // v**nan = nan, unless v == 1; 1**nan = 1 + if (v == 1.0) { return new PyFloat(1.0); } else { return new PyFloat(Double.NaN); } + + } else if (Double.isInfinite(w)) { + /* + * In Java Math pow(1,inf) = pow(-1,inf) = pow(1,-inf) = pow(-1,-inf) = nan, but in + * Python they are all 1. + */ + if (v == 1.0 || v == -1.0) { + return new PyFloat(1.0); + } + + } else if (v == 0.0) { + // 0**w is an error if w is negative. + if (w < 0.0) { + throw Py.ZeroDivisionError("0.0 cannot be raised to a negative power"); + } + + } else if (!Double.isInfinite(v) && v < 0.0) { + if (w != Math.floor(w)) { + throw Py.ValueError("negative number cannot be raised to a fractional power"); + } + } - if (value < 0 && iw != Math.floor(iw)) { - throw Py.ValueError("negative number cannot be raised to a fractional power"); - } + // In all cases not caught above we can entrust the calculation to Java + return new PyFloat(Math.pow(v, w)); - double ret = Math.pow(value, iw); - return new PyFloat(modulo == null ? 
ret : modulo(ret, coerce(modulo))); } @Override -- Repository URL: https://hg.python.org/jython From jython-checkins at python.org Wed Dec 31 02:41:07 2014 From: jython-checkins at python.org (jeff.allen) Date: Wed, 31 Dec 2014 01:41:07 +0000 Subject: [Jython-checkins] =?utf-8?q?jython=3A_Javadoc_and_tests_relevant_?= =?utf-8?q?to_Unicode_in_PythonInterpreter=2E?= Message-ID: <20141231014102.120067.53596@psf.io> https://hg.python.org/jython/rev/20e60a04d605 changeset: 7477:20e60a04d605 parent: 7432:7e2e9537565f user: Jeff Allen date: Mon Dec 15 22:45:08 2014 +0000 summary: Javadoc and tests relevant to Unicode in PythonInterpreter. Makes trivial (format only) changes to Java code and adds Javadoc clarifying the current behaviour of PythonInterpreter and its subclasses with respect to Unicode programs and Reader/Writer streams used as sys.stdin/out/err. Tests are added to exercise the behaviour, including 3 skipped tests where features may be lacking. files: Lib/test/test_pythoninterpreter_jy.py | 261 +++++++++- src/org/python/util/InteractiveConsole.java | 59 +- src/org/python/util/InteractiveInterpreter.java | 137 +++- src/org/python/util/PythonInterpreter.java | 195 ++++--- 4 files changed, 507 insertions(+), 145 deletions(-) diff --git a/Lib/test/test_pythoninterpreter_jy.py b/Lib/test/test_pythoninterpreter_jy.py --- a/Lib/test/test_pythoninterpreter_jy.py +++ b/Lib/test/test_pythoninterpreter_jy.py @@ -1,35 +1,56 @@ # -*- coding: utf-8 -*- -import java.io.StringWriter +import java.io import sys import traceback +import types import unittest import test.test_support +from org.python.core.util import StringUtil +from org.python.core import PyFile +from _codecs import encode +from sun.awt.image import BufImgVolatileSurfaceManager - -def exec_code_in_pi(function, out, err, locals=None): +def exec_code_in_pi(source, inp=None, out=None, err=None, locals=None): """Runs code in a separate context: (thread, PySystemState, PythonInterpreter)""" - def function_context(): 
+ def execution_context(): from org.python.core import Py from org.python.util import PythonInterpreter from org.python.core import PySystemState ps = PySystemState() pi = PythonInterpreter({}, ps) - if locals: - pi.setLocals(locals) - pi.setOut(out) - pi.setErr(err) + if locals is not None: pi.setLocals(locals) + if inp is not None: pi.setIn(inp) + if out is not None: pi.setOut(out) + if err is not None: pi.setErr(err) try: - pi.exec(function.func_code) + if isinstance(source, types.FunctionType): + # A function wrapping a compiled code block + pi.exec(source.func_code) + + elif isinstance(source, java.io.InputStream): + # A byte-oriented file-like input stream + pi.execfile(source) + + elif isinstance(source, java.io.Reader): + # A character-oriented file-like input stream + code = pi.compile(source) + pi.exec(code) + + else: + # A str or unicode (see UnicodeSourceTest) + pi.exec(source) + except: + print print '-'*60 traceback.print_exc(file=sys.stdout) print '-'*60 import threading - context = threading.Thread(target=function_context) + context = threading.Thread(target=execution_context) context.start() context.join() @@ -54,17 +75,17 @@ print x out = java.io.StringWriter() err = java.io.StringWriter() - exec_code_in_pi(f, out, err, {'text': source_text}) + exec_code_in_pi(f, None, out, err, {'text': source_text}) output_text = out.toString().splitlines() - for source, output in zip(source_text, output_text): - self.assertEquals(source, output) + for output, source in zip(output_text, source_text): + self.assertEquals(output, source) def test_pi_out(self): def f(): print 42 out = java.io.StringWriter() err = java.io.StringWriter() - exec_code_in_pi(f, out, err) + exec_code_in_pi(f, None, out, err) self.assertEquals(u"42\n", out.toString()) def test_more_output(self): @@ -73,16 +94,224 @@ print "*" * i out = java.io.StringWriter() err = java.io.StringWriter() - exec_code_in_pi(f, out, err) + exec_code_in_pi(f, None, out, err) output = 
out.toString().splitlines() for i, line in enumerate(output): self.assertEquals(line, u'*' * i) self.assertEquals(42, len(output)) +class UnicodeSourceTest(unittest.TestCase): + + # When the core PythonInterpreter is embedded in a Java program + # it may be supplied as Unicode source as a string or via streams. + + def do_test(self, source, ref_out=u'', ref_var={}, inp=None): + out = java.io.StringWriter() + err = java.io.StringWriter() + var = {} + if inp is not None: + if isinstance(inp, bytes): + inp = java.io.ByteArrayInputStream(StringUtil.toBytes(inp)) + elif isinstance(inp, unicode): + inp = java.io.StringReader(inp) + + exec_code_in_pi(source, inp, out, err, var) + self.assertEquals(ref_var, var) + self.assertEquals(ref_out, out.toString()) + + def test_ascii_str(self): + # Program written in bytes with ascii range only + self.do_test('a = 42\nprint a', u'42\n', {'a':42}) + + def test_latin_str(self): + # Program written in bytes with codes above 127 + self.do_test('a = "caf\xe9"\nprint a', u'caf\xe9\n', {'a':'caf\xe9'}) + + def test_ascii_unicode(self): + # Program written in Unicode with ascii range only + self.do_test(u'a = "hello"\nprint a', u'hello\n', {'a':'hello'}) + + def test_latin_unicode(self): + # Program written in Unicode with codes above 127 + self.do_test(u'a = "caf\xe9"\nprint a', u'caf\xe9\n', {'a':'caf\xe9'}) + + @unittest.skip("PythonInterpreter.exec(String) does not distinguish str/unicode") + def test_bmp_unicode(self): + # Program written in Unicode with codes above 255 + a = u"???? ?????" + prog = u'a = u"{:s}"\nprint repr(a)'.format(a) + # Submit via exec(unicode) + self.do_test(prog, + u'{}\n'.format(repr(a)), + {'a': a}) + + def test_bmp_utf8stream(self): + # Program written in Unicode with codes above 255 + a = u"???? ?????" 
+ prog = u'a = u"{:s}"\nprint repr(a)'.format(a) + # Program as bytes with declared encoding for execfile(InputStream) + progbytes = '# coding: utf-8\n' + prog.encode('utf-8') + stream = java.io.ByteArrayInputStream(StringUtil.toBytes(progbytes)) + self.do_test(stream, + u'{}\n'.format(repr(a)), + {'a': a}) + + def test_bmp_reader(self): + # Program written in Unicode with codes above 255 + a = u"???? ?????" + prog = u'a = u"{:s}"\nprint repr(a)'.format(a) + # Program as character stream for exec(compile(Reader)) + self.do_test(java.io.StringReader(prog), + u'{}\n'.format(repr(a)), + {'a': a}) + +def unicode_lines(): + input_lines = [ + u'Some text', + u'un caf? cr?me', + u"?????", + u"????", + ] + input_text = u'\n'.join(input_lines) + return input_lines, input_text + + +class InterpreterSetInTest(unittest.TestCase): + + # When the core PythonInterpreter is embedded in a Java program it + # may be connected through SetIn to a Unicode or byte stream. + # However, the Unicode Reader interface narrows the data to bytes + # in a way that mangles anything beyond Latin-1. These tests + # illustrate that preparatory to a possible solution, in which the + # encoding is specified to the PythonInterpreter and appears as + # sys.stdin.encoding etc. for use by the application (and libraries). 
+ + @staticmethod + def do_read(): + import sys + buf = bytearray() + while True: + c = sys.stdin.read(1) + if not c: break + buf.append(c) + # A defined encoding ought to be advertised in sys.stdin.encoding + enc = getattr(sys.stdin, 'encoding', None) + # In the test, allow an override via local variables + enc = locals().get('encoding', enc) + if enc: + result = buf.decode(enc) # unicode + else: + result = bytes(buf) + + def test_pi_bytes_read(self): + # Test read() with pi.setIn(PyFile(InputStream)) + input_lines, input_text = unicode_lines() + input_bytes = input_text.encode('utf-8') + inp = java.io.ByteArrayInputStream(input_bytes) + var = dict() + exec_code_in_pi(InterpreterSetInTest.do_read, inp, locals=var) + result = var['result'] + self.assertEquals(result, input_bytes) + self.assertEquals(type(result), type(input_bytes)) + + @unittest.skip("Jython treats characters from a Reader as bytes.") + # Has no unicode encoding and fails to build PyString for codes > 255 + def test_pi_unicode_read(self): + # Test read() with pi.setIn(Reader) + input_lines, input_text = unicode_lines() + inp = java.io.StringReader(input_text) + var = dict() + exec_code_in_pi(InterpreterSetInTest.do_read, inp, locals=var) + result = var['result'] + self.assertEquals(result, input_text) + self.assertEquals(type(result), type(input_text)) + + def test_pi_encoding_read(self): + # Test read() with pi.setIn(PyFile(InputStream)) and defined encoding + input_lines, input_text = unicode_lines() + input_bytes = input_text.encode('utf-8') + inp = java.io.ByteArrayInputStream(input_bytes) + var = {'encoding': 'utf-8'} + exec_code_in_pi(InterpreterSetInTest.do_read, inp, locals=var) + result = var['result'] + self.assertEquals(result, input_text) + self.assertEquals(type(result), type(input_text)) + + @staticmethod + def do_readline(): + import sys + # A defined encoding ought to be advertised in sys.stdin.encoding + enc = getattr(sys.stdin, 'encoding', None) + # In the test, allow an override 
via local variables + enc = locals().get('encoding', enc) + result = list() + while True: + line = sys.stdin.readline() + if not line: break + if enc: line = line.decode(enc) # unicode + result.append(line.rstrip('\n')) + + def test_pi_bytes_readline(self): + # Test readline() with pi.setIn(PyFile(InputStream)) + input_lines, input_text = unicode_lines() + input_bytes = input_text.encode('utf-8') + inp = java.io.ByteArrayInputStream(input_bytes) + var = dict() + exec_code_in_pi(InterpreterSetInTest.do_readline, inp, locals=var) + for output, source in zip(var['result'], input_lines): + source = source.encode('utf-8') + self.assertEquals(output, source) + self.assertEquals(type(output), type(source)) + + @unittest.skip("Jython treats characters from a Reader as bytes.") + # Has no unicode encoding and fails to build PyString for codes > 255 + def test_pi_unicode_readline(self): + # Test readline() with pi.setIn(Reader) + input_lines, input_text = unicode_lines() + inp = java.io.StringReader(input_text) + var = dict() + exec_code_in_pi(InterpreterSetInTest.do_readline, inp, locals=var) + for output, source in zip(var['result'], input_lines): + self.assertEquals(output, source) + self.assertEquals(type(output), type(source)) + + def test_pi_encoding_readline(self): + # Test readline() pi.setIn(PyFile(InputStream)) and defined encoding + input_lines, input_text = unicode_lines() + input_bytes = input_text.encode('utf-8') + inp = java.io.ByteArrayInputStream(input_bytes) + var = {'encoding': 'utf-8'} + exec_code_in_pi(InterpreterSetInTest.do_readline, inp, locals=var) + for output, source in zip(var['result'], input_lines): + self.assertEquals(output, source) + self.assertEquals(type(output), type(source)) + + @staticmethod + def do_readinto(): + import sys + buf = bytearray(1024) + n = sys.stdin.readinto(buf) + result = buf[:n] + + def test_pi_bytes_readinto(self): + # Test readinto() with pi.setIn(PyFile(InputStream)) + input_lines, input_text = unicode_lines() + 
input_bytes = input_text.encode('utf-8') + inp = java.io.ByteArrayInputStream(input_bytes) + var = dict() + exec_code_in_pi(InterpreterSetInTest.do_readinto, inp, locals=var) + self.assertEquals(var['result'], input_bytes) + + # There is no readinto() with pi.setIn(Reader) + def test_main(): - test.test_support.run_unittest(InterpreterTest) + test.test_support.run_unittest( + InterpreterTest, + UnicodeSourceTest, + InterpreterSetInTest, + ) if __name__ == "__main__": test_main() diff --git a/src/org/python/util/InteractiveConsole.java b/src/org/python/util/InteractiveConsole.java --- a/src/org/python/util/InteractiveConsole.java +++ b/src/org/python/util/InteractiveConsole.java @@ -8,6 +8,16 @@ import org.python.core.PySystemState; import org.python.core.__builtin__; +/** + * This class provides the read, execute, print loop needed by a Python console; it is not actually + * a console itself. The primary capability is the {@link #interact()} method, which repeatedly + * calls {@link #raw_input(PyObject)}, and hence {@link __builtin__#raw_input(PyObject)}, in order + * to get lines, and {@link #push(String)} them into the interpreter. The built-in + * raw_input() method prompts on sys.stdout and reads from + * sys.stdin, the standard console. These may be redirected using + * {@link #setOut(java.io.OutputStream)} and {@link #setIn(java.io.InputStream)}, as may also + * sys.stderr. + */ // Based on CPython-1.5.2's code module public class InteractiveConsole extends InteractiveInterpreter { @@ -15,21 +25,43 @@ public String filename; + /** + * Construct an interactive console, which will "run" when {@link #interact()} is called. The + * name of the console (e.g. in error messages) will be {@value #CONSOLE_FILENAME}. + */ public InteractiveConsole() { this(null, CONSOLE_FILENAME); } + /** + * Construct an interactive console, which will "run" when {@link #interact()} is called. The + * name of the console (e.g. in error messages) will be {@value #CONSOLE_FILENAME}. 
+ * + * @param locals dictionary to use, or if null, a new empty one will be created + */ public InteractiveConsole(PyObject locals) { this(locals, CONSOLE_FILENAME); } + /** + * Construct an interactive console, which will "run" when {@link #interact()} is called. + * + * @param locals dictionary to use, or if null, a new empty one will be created + * @param filename name with which to label this console input (e.g. in error messages). + */ public InteractiveConsole(PyObject locals, String filename) { this(locals, filename, false); } /** - * @param replaceRawInput if true, we hook this Class's raw_input into the built-ins table so - * that clients like cmd.Cmd use it. + * Full-feature constructor for an interactive console, which will "run" when + * {@link #interact()} is called. This version allows the caller to replace the built-in + * raw_input() methods with {@link #raw_input(PyObject)} and + * {@link #raw_input(PyObject, PyObject)}, which may be overridden in a sub-class. + * + * @param locals dictionary to use, or if null, a new empty one will be created + * @param filename name with which to label this console input + * @param replaceRawInput if true, hook this class's raw_input into the built-ins. */ public InteractiveConsole(PyObject locals, String filename, boolean replaceRawInput) { super(locals); @@ -52,19 +84,32 @@ } /** - * Closely emulate the interactive Python console. - * - * The optional banner argument specifies the banner to print before the first interaction; by - * default it prints "Jython on ". + * Operate a Python console, as in {@link #interact(String, PyObject)}, on the standard input. + * The standard input may have been redirected by {@link #setIn(java.io.InputStream)} or its + * variants. The banner (printed before first input) is obtained by calling + * {@link #getDefaultBanner()}. */ public void interact() { interact(getDefaultBanner(), null); } + /** + * Returns the banner to print before the first interaction: "Jython on ". 
+ * + * @return the banner. + */ public static String getDefaultBanner() { - return String.format("Jython %s on %s", PySystemState.version, Py.getSystemState().platform); + return String + .format("Jython %s on %s", PySystemState.version, Py.getSystemState().platform); } + /** + * Operate a Python console by repeatedly calling {@link #raw_input(PyObject, PyObject)} and + * interpreting the lines read. An end of file causes the method to return. + * + * @param banner to print before accepting input, or if null, no banner. + * @param file from which to read commands, or if null, read the console. + */ public void interact(String banner, PyObject file) { if (banner != null) { write(banner); diff --git a/src/org/python/util/InteractiveInterpreter.java b/src/org/python/util/InteractiveInterpreter.java --- a/src/org/python/util/InteractiveInterpreter.java +++ b/src/org/python/util/InteractiveInterpreter.java @@ -1,51 +1,96 @@ // Copyright (c) Corporation for National Research Initiatives package org.python.util; + import org.python.core.*; +/** + * This class provides the interface for compiling and running code that supports an interactive + * interpreter. + */ // Based on CPython-1.5.2's code module +public class InteractiveInterpreter extends PythonInterpreter { -public class InteractiveInterpreter extends PythonInterpreter { + /** + * Construct an InteractiveInterpreter with all default characteristics: default state (from + * {@link Py#getSystemState()}), and a new empty dictionary of local variables. + * */ public InteractiveInterpreter() { this(null); } + + /** + * Construct an InteractiveInterpreter with state (from {@link Py#getSystemState()}), and the + * specified dictionary of local variables. 
+ * + * @param locals dictionary to use, or if null, a new empty one will be created + */ public InteractiveInterpreter(PyObject locals) { this(locals, null); - } - public InteractiveInterpreter(PyObject locals, PySystemState systemState) { - super(locals, systemState); - } /** - * Compile and run some source in the interpreter. + * Construct an InteractiveInterpreter with the specified system state and dictionary of local + * variables. * - * Arguments are as for compile_command(). + * @param locals dictionary to use, or if null, a new empty one will be created + * @param systemState interpreter state, or if null use {@link Py#getSystemState()} + */ + public InteractiveInterpreter(PyObject locals, PySystemState systemState) { + super(locals, systemState); + } + + /** + * Compile and run some source in the interpreter, in the mode {@link CompileMode#single} which + * is used for incremental compilation at the interactive console, known as "<input>". * - * One several things can happen: - * - * 1) The input is incorrect; compile_command() raised an exception - * (SyntaxError or OverflowError). A syntax traceback will be printed - * by calling the showsyntaxerror() method. - * - * 2) The input is incomplete, and more input is required; - * compile_command() returned None. Nothing happens. - * - * 3) The input is complete; compile_command() returned a code object. - * The code is executed by calling self.runcode() (which also handles - * run-time exceptions, except for SystemExit). - * - * The return value is 1 in case 2, 0 in the other cases (unless an - * exception is raised). The return value can be used to decide - * whether to use sys.ps1 or sys.ps2 to prompt the next line. 
- **/ + * @param source Python code + * @return true to indicate a partial statement was entered + */ public boolean runsource(String source) { return runsource(source, "<input>", CompileMode.single); } + /** + * Compile and run some source in the interpreter, in the mode {@link CompileMode#single} which + * is used for incremental compilation at the interactive console. + * + * @param source Python code + * @param filename name with which to label this console input (e.g. in error messages). + * @return true to indicate a partial statement was entered + */ public boolean runsource(String source, String filename) { return runsource(source, filename, CompileMode.single); } + /** + * Compile and run some source in the interpreter, according to the {@link CompileMode} given. + * This method supports incremental compilation and interpretation through the return value, + * where true signifies that more input is expected in order to complete the Python + * statement. An interpreter can use this to decide whether to use sys.ps1 (" + * >>> ") or sys.ps2 ("... ") to prompt the next line. + * The arguments are the same as the mandatory ones in the Python compile() + * command. + *

+ * One of the following can happen: + *

    + *
  1. The input is incorrect; compilation raised an exception (SyntaxError or OverflowError). A + * syntax traceback will be printed by calling {@link #showexception(PyException)}. Return is + * false.
  2. + * + *
  3. The input is incomplete, and more input is required; compilation returned no code. + * Nothing happens. Return is true.
  4. + * + *
  5. The input is complete; compilation returned a code object. The code is executed by + * calling {@link #runcode(PyObject)} (which also handles run-time exceptions, except for + * SystemExit). Return is false.
  6. + *
+ * + * @param source Python code + * @param filename name with which to label this console input (e.g. in error messages). + * @param kind of compilation required: {@link CompileMode#eval}, {@link CompileMode#exec} or + * {@link CompileMode#single} + * @return true to indicate a partial statement was provided + */ public boolean runsource(String source, String filename, CompileMode kind) { PyObject code; try { @@ -64,23 +109,20 @@ } } // Case 2 - if (code == Py.None) + if (code == Py.None) { return true; + } // Case 3 runcode(code); return false; } /** - * execute a code object. - * - * When an exception occurs, self.showtraceback() is called to display - * a traceback. All exceptions are caught except SystemExit, which is - * reraised. - * - * A note about KeyboardInterrupt: this exception may occur elsewhere - * in this code, and may not always be caught. The caller should be - * prepared to deal with it. + * Execute a code object. When an exception occurs, {@link #showexception(PyException)} is + * called to display a stack trace, except in the case of SystemExit, which is re-raised. + *

+ * A note about KeyboardInterrupt: this exception may occur elsewhere in this code, and may not + * always be caught. The caller should be prepared to deal with it. **/ // Make this run in another thread somehow???? @@ -88,7 +130,9 @@ try { exec(code); } catch (PyException exc) { - if (exc.match(Py.SystemExit)) throw exc; + if (exc.match(Py.SystemExit)) { + throw exc; + } showexception(exc); } } @@ -96,7 +140,7 @@ public void showexception(PyException exc) { // Should probably add code to handle skipping top stack frames // somehow... - Py.printException(exc); + Py.printException(exc); } public void write(String data) { @@ -104,48 +148,55 @@ } public StringBuffer buffer = new StringBuffer(); - public String filename=""; + public String filename = ""; public void resetbuffer() { buffer.setLength(0); } - /** Pause the current code, sneak an exception raiser into - * sys.trace_func, and then continue the code hoping that Jython will - * get control to do the break; + /** + * Pause the current code, sneak an exception raiser into sys.trace_func, and then continue the + * code hoping that Jython will get control to do the break; **/ public void interrupt(ThreadState ts) { TraceFunction breaker = new BreakTraceFunction(); TraceFunction oldTrace = ts.tracefunc; ts.tracefunc = breaker; - if (ts.frame != null) + if (ts.frame != null) { ts.frame.tracefunc = breaker; + } ts.tracefunc = oldTrace; - //ts.thread.join(); + // ts.thread.join(); } } + class BreakTraceFunction extends TraceFunction { + private void doBreak() { throw new Error("Python interrupt"); - //Thread.currentThread().interrupt(); + // Thread.currentThread().interrupt(); } + @Override public TraceFunction traceCall(PyFrame frame) { doBreak(); return null; } + @Override public TraceFunction traceReturn(PyFrame frame, PyObject ret) { doBreak(); return null; } + @Override public TraceFunction traceLine(PyFrame frame, int line) { doBreak(); return null; } + @Override public TraceFunction traceException(PyFrame 
frame, PyException exc) { doBreak(); return null; diff --git a/src/org/python/util/PythonInterpreter.java b/src/org/python/util/PythonInterpreter.java --- a/src/org/python/util/PythonInterpreter.java +++ b/src/org/python/util/PythonInterpreter.java @@ -25,8 +25,8 @@ import org.python.core.PyFileReader; /** - * The PythonInterpreter class is a standard wrapper for a Jython interpreter - * for embedding in a Java application. + * The PythonInterpreter class is a standard wrapper for a Jython interpreter for embedding in a + * Java application. */ public class PythonInterpreter implements AutoCloseable, Closeable { @@ -37,6 +37,7 @@ protected final boolean useThreadLocalState; protected static ThreadLocal threadLocals = new ThreadLocal() { + @Override protected Object[] initialValue() { return new Object[1]; @@ -48,24 +49,18 @@ private volatile boolean closed = false; /** - * Initializes the Jython runtime. This should only be called - * once, before any other Python objects (including - * PythonInterpreter) are created. + * Initializes the Jython runtime. This should only be called once, before any other Python + * objects (including PythonInterpreter) are created. * - * @param preProperties - * A set of properties. Typically - * System.getProperties() is used. preProperties - * override properties from the registry file. - * @param postProperties - * Another set of properties. Values like python.home, - * python.path and all other values from the registry - * files can be added to this property - * set. postProperties override system properties and - * registry properties. - * @param argv - * Command line arguments, assigned to sys.argv. + * @param preProperties A set of properties. Typically System.getProperties() is used. + * preProperties override properties from the registry file. + * @param postProperties Another set of properties. Values like python.home, python.path and all + * other values from the registry files can be added to this property set. 
+ * postProperties override system properties and registry properties. + * @param argv Command line arguments, assigned to sys.argv. */ - public static void initialize(Properties preProperties, Properties postProperties, String[] argv) { + public static void + initialize(Properties preProperties, Properties postProperties, String[] argv) { PySystemState.initialize(preProperties, postProperties, argv); } @@ -77,13 +72,10 @@ } /** - * Creates a new interpreter with the ability to maintain a - * separate local namespace for each thread (set by invoking - * setLocals()). + * Creates a new interpreter with the ability to maintain a separate local namespace for each + * thread (set by invoking setLocals()). * - * @param dict - * a Python mapping object (e.g., a dictionary) for use - * as the default namespace + * @param dict a Python mapping object (e.g., a dictionary) for use as the default namespace */ public static PythonInterpreter threadLocalStateInterpreter(PyObject dict) { return new PythonInterpreter(dict, null, true); @@ -92,9 +84,7 @@ /** * Creates a new interpreter with a specified local namespace. 
* - * @param dict - * a Python mapping object (e.g., a dictionary) for use - * as the namespace + * @param dict a Python mapping object (e.g., a dictionary) for use as the namespace */ public PythonInterpreter(PyObject dict) { this(dict, null); @@ -104,14 +94,16 @@ this(dict, systemState, false); } - protected PythonInterpreter(PyObject dict, PySystemState systemState, boolean useThreadLocalState) { + protected PythonInterpreter(PyObject dict, PySystemState systemState, + boolean useThreadLocalState) { if (dict == null) { dict = Py.newStringMap(); } globals = dict; - if (systemState == null) + if (systemState == null) { systemState = Py.getSystemState(); + } this.systemState = systemState; setSystemState(); @@ -120,7 +112,7 @@ PyModule module = new PyModule("__main__", dict); systemState.modules.__setitem__("__main__", module); } - + if (Options.importSite) { // Ensure site-packages are available imp.load("site"); @@ -136,59 +128,116 @@ } /** - * Sets a Python object to use for the standard input stream. + * Sets a Python object to use for the standard input stream, sys.stdin. This + * stream is used in a byte-oriented way, through calls to read and + * readline on the object. * - * @param inStream - * a Python file-like object to use as input stream + * @param inStream a Python file-like object to use as the input stream */ public void setIn(PyObject inStream) { getSystemState().stdin = inStream; } + /** + * Sets a {@link Reader} to use for the standard input stream, sys.stdin. This + * stream is wrapped such that characters will be narrowed to bytes. A character greater than + * U+00FF will raise a Java IllegalArgumentException from within + * {@link PyString}. + * + * @param inStream to use as the input stream + */ public void setIn(java.io.Reader inStream) { setIn(new PyFileReader(inStream)); } /** - * Sets a java.io.InputStream to use for the standard input - * stream. + * Sets a {@link java.io.InputStream} to use for the standard input stream. 
* - * @param inStream - * InputStream to use as input stream + * @param inStream InputStream to use as input stream */ public void setIn(java.io.InputStream inStream) { setIn(new PyFile(inStream)); } /** - * Sets a Python object to use for the standard output stream. + * Sets a Python object to use for the standard output stream, sys.stdout. This + * stream is used in a byte-oriented way (mostly) that depends on the type of file-like object. + * The behaviour as implemented is: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Python type of object o written
str/bytesunicodeAny other type
{@link PyFile}as bytes directlyrespect {@link PyFile#encoding}call str(o) first
{@link PyFileWriter}each byte value as a charwrite as Java charscall o.toString() first
Other {@link PyObject} finvoke f.write(str(o))invoke f.write(o)invoke f.write(str(o))
* - * @param outStream - * Python file-like object to use as output stream + * @param outStream Python file-like object to use as the output stream */ public void setOut(PyObject outStream) { getSystemState().stdout = outStream; } + /** + * Sets a {@link Writer} to use for the standard output stream, sys.stdout. The + * behaviour as implemented is to output each object o by calling + * o.toString() and writing this as UTF-16. + * + * @param outStream to use as the output stream + */ public void setOut(java.io.Writer outStream) { setOut(new PyFileWriter(outStream)); } /** - * Sets a java.io.OutputStream to use for the standard output - * stream. + * Sets a {@link java.io.OutputStream} to use for the standard output stream. * - * @param outStream - * OutputStream to use as output stream + * @param outStream OutputStream to use as output stream */ public void setOut(java.io.OutputStream outStream) { setOut(new PyFile(outStream)); } + /** + * Sets a Python object to use for the standard error stream, sys.stderr. This + * stream is used in a byte-oriented way (mostly) that depends on the type of file-like object, + * in the same way as {@link #setOut(PyObject)}. + * + * @param outStream Python file-like object to use as the error output stream + */ public void setErr(PyObject outStream) { getSystemState().stderr = outStream; } + /** + * Sets a {@link Writer} to use for the standard error stream, sys.stderr. The + * behaviour as implemented is to output each object o by calling + * o.toString() and writing this as UTF-16. + * + * @param outStream to use as the error output stream + */ public void setErr(java.io.Writer outStream) { setErr(new PyFileWriter(outStream)); } @@ -198,8 +247,7 @@ } /** - * Evaluates a string as a Python expression and returns the - * result. + * Evaluates a string as a Python expression and returns the result. 
*/ public PyObject eval(String s) { setSystemState(); @@ -253,36 +301,35 @@ } /** - * Compiles a string of Python source as either an expression (if - * possible) or a module. + * Compiles a string of Python source as either an expression (if possible) or a module. * - * Designed for use by a JSR 223 implementation: "the Scripting - * API does not distinguish between scripts which return values - * and those which do not, nor do they make the corresponding - * distinction between evaluating or executing objects." - * (SCR.4.2.1) + * Designed for use by a JSR 223 implementation: "the Scripting API does not distinguish between + * scripts which return values and those which do not, nor do they make the corresponding + * distinction between evaluating or executing objects." (SCR.4.2.1) */ public PyCode compile(String script) { return compile(script, "