[Numpy-svn] r8705 - in branches/1.5.x/numpy/core: include/numpy src/multiarray tests

numpy-svn at scipy.org numpy-svn at scipy.org
Sat Sep 11 12:09:34 EDT 2010


Author: ptvirtan
Date: 2010-09-11 11:09:34 -0500 (Sat, 11 Sep 2010)
New Revision: 8705

Modified:
   branches/1.5.x/numpy/core/include/numpy/npy_3kcompat.h
   branches/1.5.x/numpy/core/src/multiarray/methods.c
   branches/1.5.x/numpy/core/src/multiarray/multiarraymodule.c
   branches/1.5.x/numpy/core/tests/test_regression.py
Log:
BUG: (backport r8701) core: sync Python 3 file handle position in tofile/fromfile (fixes #1610)

Modified: branches/1.5.x/numpy/core/include/numpy/npy_3kcompat.h
===================================================================
--- branches/1.5.x/numpy/core/include/numpy/npy_3kcompat.h	2010-09-11 16:09:12 UTC (rev 8704)
+++ branches/1.5.x/numpy/core/include/numpy/npy_3kcompat.h	2010-09-11 16:09:34 UTC (rev 8705)
@@ -149,14 +149,19 @@
 #endif
 
 /*
- * PyFile_AsFile
+ * PyFile_* compatibility
  */
 #if defined(NPY_PY3K)
+
+/*
+ * Get a FILE* handle to the file represented by the Python object
+ */
 static NPY_INLINE FILE*
 npy_PyFile_Dup(PyObject *file, char *mode)
 {
     int fd, fd2;
     PyObject *ret, *os;
+    FILE *handle;
     /* Flush first to ensure things end up in the file in the correct order */
     ret = PyObject_CallMethod(file, "flush", "");
     if (ret == NULL) {
@@ -179,11 +184,91 @@
     fd2 = PyNumber_AsSsize_t(ret, NULL);
     Py_DECREF(ret);
 #ifdef _WIN32
-    return _fdopen(fd2, mode);
+    handle = _fdopen(fd2, mode);
 #else
-    return fdopen(fd2, mode);
+    handle = fdopen(fd2, mode);
 #endif
+    if (handle == NULL) {
+        PyErr_SetString(PyExc_IOError,
+                        "Getting a FILE* from a Python file object failed");
+    }
+    return handle;
 }
+
+/*
+ * Close the dup-ed FILE* handle and seek the Python file object to the
+ * position the C-level handle ended at, so that both agree afterwards.
+ *
+ * Returns 0 on success, or -1 with a Python exception set if the
+ * position could not be read, the handle could not be closed, or the
+ * Python-level seek() call failed.
+ */
+static NPY_INLINE int
+npy_PyFile_DupClose(PyObject *file, FILE* handle)
+{
+    PyObject *ret;
+    long position;
+
+    position = ftell(handle);
+    /* Always close the handle; fclose also flushes buffered writes. */
+    if (fclose(handle) != 0) {
+        PyErr_SetString(PyExc_IOError,
+                        "Closing the duplicated FILE* handle failed");
+        return -1;
+    }
+    if (position == -1) {
+        PyErr_SetString(PyExc_IOError,
+                        "Getting the position of the FILE* handle failed");
+        return -1;
+    }
+
+    ret = PyObject_CallMethod(file, "seek", "li", position, 0);
+    if (ret == NULL) {
+        return -1;
+    }
+    Py_DECREF(ret);
+    return 0;
+}
+
+/*
+ * Return non-zero if `file` is an instance of io.FileIO, 0 otherwise.
+ *
+ * The io.FileIO type is looked up on first use and cached; the cached
+ * reference is deliberately never released. If the lookup fails, the
+ * Python error is cleared and 0 is returned, so that -- like PyFile_Check
+ * on Python 2 -- this predicate never leaves an exception pending.
+ */
+static NPY_INLINE int
+npy_PyFile_Check(PyObject *file)
+{
+    static PyTypeObject *fileio = NULL;
+
+    if (fileio == NULL) {
+        PyObject *mod = PyImport_ImportModule("io");
+        if (mod == NULL) {
+            PyErr_Clear();
+            return 0;
+        }
+        fileio = (PyTypeObject*)PyObject_GetAttrString(mod, "FileIO");
+        Py_DECREF(mod);
+        if (fileio == NULL) {
+            PyErr_Clear();
+            return 0;
+        }
+    }
+    return PyObject_TypeCheck(file, fileio);
+}
+
+#else
+
+/*
+ * Python 2: PyFile_AsFile returns the file object's own FILE*, so no
+ * duplicate is made and there is nothing to close or re-sync.
+ */
+#define npy_PyFile_Dup(file, mode) PyFile_AsFile(file)
+#define npy_PyFile_DupClose(file, handle) (0)
+#define npy_PyFile_Check PyFile_Check
+
 #endif
 
 static NPY_INLINE PyObject*

Modified: branches/1.5.x/numpy/core/src/multiarray/methods.c
===================================================================
--- branches/1.5.x/numpy/core/src/multiarray/methods.c	2010-09-11 16:09:12 UTC (rev 8704)
+++ branches/1.5.x/numpy/core/src/multiarray/methods.c	2010-09-11 16:09:34 UTC (rev 8705)
@@ -496,7 +496,7 @@
 static PyObject *
 array_tofile(PyArrayObject *self, PyObject *args, PyObject *kwds)
 {
-    int ret;
+    int ret, ret2;
     PyObject *file;
     FILE *fd;
     char *sep = "";
@@ -517,11 +517,7 @@
     else {
         Py_INCREF(file);
     }
-#if defined(NPY_PY3K)
     fd = npy_PyFile_Dup(file, "wb");
-#else
-    fd = PyFile_AsFile(file);
-#endif
     if (fd == NULL) {
         PyErr_SetString(PyExc_IOError, "first argument must be a " \
                         "string or open file");
@@ -529,11 +525,9 @@
         return NULL;
     }
     ret = PyArray_ToFile(self, fd, sep, format);
-#if defined(NPY_PY3K)
-    fclose(fd);
-#endif
+    ret2 = npy_PyFile_DupClose(file, fd);
     Py_DECREF(file);
-    if (ret < 0) {
+    if (ret < 0 || ret2 < 0) {
         return NULL;
     }
     Py_INCREF(Py_None);

Modified: branches/1.5.x/numpy/core/src/multiarray/multiarraymodule.c
===================================================================
--- branches/1.5.x/numpy/core/src/multiarray/multiarraymodule.c	2010-09-11 16:09:12 UTC (rev 8704)
+++ branches/1.5.x/numpy/core/src/multiarray/multiarraymodule.c	2010-09-11 16:09:34 UTC (rev 8705)
@@ -1669,6 +1669,7 @@
 array_fromfile(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *keywds)
 {
     PyObject *file = NULL, *ret;
+    int ok;
     FILE *fp;
     char *sep = "";
     Py_ssize_t nin = -1;
@@ -1690,11 +1691,7 @@
     else {
         Py_INCREF(file);
     }
-#if defined(NPY_PY3K)
     fp = npy_PyFile_Dup(file, "rb");
-#else
-    fp = PyFile_AsFile(file);
-#endif
     if (fp == NULL) {
         PyErr_SetString(PyExc_IOError,
                 "first argument must be an open file");
@@ -1705,10 +1702,13 @@
         type = PyArray_DescrFromType(PyArray_DEFAULT);
     }
     ret = PyArray_FromFile(fp, type, (intp) nin, sep);
-#if defined(NPY_PY3K)
-    fclose(fp);
-#endif
+    ok = npy_PyFile_DupClose(file, fp);
     Py_DECREF(file);
+    if (ok < 0) {
+        /* ret is NULL here if PyArray_FromFile itself failed */
+        Py_XDECREF(ret);
+        return NULL;
+    }
     return ret;
 }
 

Modified: branches/1.5.x/numpy/core/tests/test_regression.py
===================================================================
--- branches/1.5.x/numpy/core/tests/test_regression.py	2010-09-11 16:09:12 UTC (rev 8704)
+++ branches/1.5.x/numpy/core/tests/test_regression.py	2010-09-11 16:09:34 UTC (rev 8705)
@@ -7,6 +7,7 @@
 from numpy.testing import *
 from numpy.testing.utils import _assert_valid_refcount
 from numpy.compat import asbytes, asunicode, asbytes_nested
+import tempfile
 import numpy as np
 
 if sys.version_info[0] >= 3:
@@ -1377,5 +1378,34 @@
         c2 = sys.getrefcount(rgba)
         assert_equal(c1, c2)
 
+    def test_fromfile_tofile_seeks(self):
+        """tofile/fromfile must leave the Python file position in sync.
+
+        Regression test for #1610: on Python 3 these methods work on a
+        dup-ed FILE* handle, and the Python file object's position used
+        not to be updated to reflect what they read or wrote.
+        """
+        f = tempfile.TemporaryFile()
+        try:
+            f.write(np.arange(255, dtype='u1').tostring())
+
+            # fromfile must advance the Python position past what it read
+            f.seek(20)
+            ret = np.fromfile(f, count=4, dtype='u1')
+            assert_equal(ret, np.array([20, 21, 22, 23], dtype='u1'))
+            assert_equal(f.tell(), 24)
+
+            # tofile must advance the Python position past what it wrote
+            f.seek(40)
+            np.array([1, 2, 3], dtype='u1').tofile(f)
+            assert_equal(f.tell(), 43)
+
+            # ... and the bytes must land where seek() positioned them
+            f.seek(40)
+            data = f.read(3)
+            assert_equal(data, asbytes("\x01\x02\x03"))
+        finally:
+            f.close()
+
 if __name__ == "__main__":
     run_module_suite()




More information about the Numpy-svn mailing list