[Python-checkins] bpo-32630: Use contextvars in decimal (GH-5278)

Yury Selivanov webhook-mailer at python.org
Sat Jan 27 13:46:49 EST 2018


https://github.com/python/cpython/commit/f13f12d8daa587b5fcc66fe3ed1090a5dadab289
commit: f13f12d8daa587b5fcc66fe3ed1090a5dadab289
branch: master
author: Yury Selivanov <yury at magic.io>
committer: GitHub <noreply at github.com>
date: 2018-01-27T13:46:46-05:00
summary:

bpo-32630: Use contextvars in decimal (GH-5278)

files:
A Lib/test/test_asyncio/test_context.py
A Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst
M Lib/_pydecimal.py
M Modules/_decimal/_decimal.c

diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py
index a1662bbd671..359690003fe 100644
--- a/Lib/_pydecimal.py
+++ b/Lib/_pydecimal.py
@@ -433,13 +433,11 @@ class FloatOperation(DecimalException, TypeError):
 # The getcontext() and setcontext() function manage access to a thread-local
 # current context.
 
-import threading
+import contextvars
 
-local = threading.local()
-if hasattr(local, '__decimal_context__'):
-    del local.__decimal_context__
+_current_context_var = contextvars.ContextVar('decimal_context')
 
-def getcontext(_local=local):
+def getcontext():
     """Returns this thread's context.
 
     If this thread does not yet have a context, returns
@@ -447,20 +445,20 @@ def getcontext(_local=local):
     New contexts are copies of DefaultContext.
     """
     try:
-        return _local.__decimal_context__
-    except AttributeError:
+        return _current_context_var.get()
+    except LookupError:
         context = Context()
-        _local.__decimal_context__ = context
+        _current_context_var.set(context)
         return context
 
-def setcontext(context, _local=local):
+def setcontext(context):
     """Set this thread's context to context."""
     if context in (DefaultContext, BasicContext, ExtendedContext):
         context = context.copy()
         context.clear_flags()
-    _local.__decimal_context__ = context
+    _current_context_var.set(context)
 
-del threading, local        # Don't contaminate the namespace
+del contextvars        # Don't contaminate the namespace
 
 def localcontext(ctx=None):
     """Return a context manager for a copy of the supplied context
diff --git a/Lib/test/test_asyncio/test_context.py b/Lib/test/test_asyncio/test_context.py
new file mode 100644
index 00000000000..6abddd9f251
--- /dev/null
+++ b/Lib/test/test_asyncio/test_context.py
@@ -0,0 +1,29 @@
+import asyncio
+import decimal
+import unittest
+
+
+class DecimalContextTest(unittest.TestCase):
+
+    def test_asyncio_task_decimal_context(self):
+        async def fractions(t, precision, x, y):
+            with decimal.localcontext() as ctx:
+                ctx.prec = precision
+                a = decimal.Decimal(x) / decimal.Decimal(y)
+                await asyncio.sleep(t)
+                b = decimal.Decimal(x) / decimal.Decimal(y ** 2)
+                return a, b
+
+        async def main():
+            r1, r2 = await asyncio.gather(
+                fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3))
+
+            return r1, r2
+
+        r1, r2 = asyncio.run(main())
+
+        self.assertEqual(str(r1[0]), '0.333')
+        self.assertEqual(str(r1[1]), '0.111')
+
+        self.assertEqual(str(r2[0]), '0.333333')
+        self.assertEqual(str(r2[1]), '0.111111')
diff --git a/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst b/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst
new file mode 100644
index 00000000000..1bbcbb173eb
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst
@@ -0,0 +1 @@
+Refactor decimal module to use contextvars to store decimal context.
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c
index 18fa2e4fa5e..fddb39ef652 100644
--- a/Modules/_decimal/_decimal.c
+++ b/Modules/_decimal/_decimal.c
@@ -122,10 +122,7 @@ incr_false(void)
 }
 
 
-/* Key for thread state dictionary */
-static PyObject *tls_context_key = NULL;
-/* Invariant: NULL or the most recently accessed thread local context */
-static PyDecContextObject *cached_context = NULL;
+static PyContextVar *current_context_var;
 
 /* Template for creating new thread contexts, calling Context() without
  * arguments and initializing the module_context on first access. */
@@ -1220,10 +1217,6 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED)
 static void
 context_dealloc(PyDecContextObject *self)
 {
-    if (self == cached_context) {
-        cached_context = NULL;
-    }
-
     Py_XDECREF(self->traps);
     Py_XDECREF(self->flags);
     Py_TYPE(self)->tp_free(self);
@@ -1498,69 +1491,38 @@ static PyGetSetDef context_getsets [] =
  * operation.
  */
 
-/* Get the context from the thread state dictionary. */
 static PyObject *
-current_context_from_dict(void)
+init_current_context(void)
 {
-    PyObject *dict;
-    PyObject *tl_context;
-    PyThreadState *tstate;
-
-    dict = PyThreadState_GetDict();
-    if (dict == NULL) {
-        PyErr_SetString(PyExc_RuntimeError,
-            "cannot get thread state");
+    PyObject *tl_context = context_copy(default_context_template, NULL);
+    if (tl_context == NULL) {
         return NULL;
     }
+    CTX(tl_context)->status = 0;
 
-    tl_context = PyDict_GetItemWithError(dict, tls_context_key);
-    if (tl_context != NULL) {
-        /* We already have a thread local context. */
-        CONTEXT_CHECK(tl_context);
-    }
-    else {
-        if (PyErr_Occurred()) {
-            return NULL;
-        }
-
-        /* Set up a new thread local context. */
-        tl_context = context_copy(default_context_template, NULL);
-        if (tl_context == NULL) {
-            return NULL;
-        }
-        CTX(tl_context)->status = 0;
-
-        if (PyDict_SetItem(dict, tls_context_key, tl_context) < 0) {
-            Py_DECREF(tl_context);
-            return NULL;
-        }
+    PyContextToken *tok = PyContextVar_Set(current_context_var, tl_context);
+    if (tok == NULL) {
         Py_DECREF(tl_context);
+        return NULL;
     }
+    Py_DECREF(tok);
 
-    /* Cache the context of the current thread, assuming that it
-     * will be accessed several times before a thread switch. */
-    tstate = PyThreadState_GET();
-    if (tstate) {
-        cached_context = (PyDecContextObject *)tl_context;
-        cached_context->tstate = tstate;
-    }
-
-    /* Borrowed reference with refcount==1 */
     return tl_context;
 }
 
-/* Return borrowed reference to thread local context. */
-static PyObject *
+static inline PyObject *
 current_context(void)
 {
-    PyThreadState *tstate;
+    PyObject *tl_context;
+    if (PyContextVar_Get(current_context_var, NULL, &tl_context) < 0) {
+        return NULL;
+    }
 
-    tstate = PyThreadState_GET();
-    if (cached_context && cached_context->tstate == tstate) {
-        return (PyObject *)cached_context;
+    if (tl_context != NULL) {
+        return tl_context;
     }
 
-    return current_context_from_dict();
+    return init_current_context();
 }
 
 /* ctxobj := borrowed reference to the current context */
@@ -1568,47 +1530,22 @@ current_context(void)
     ctxobj = current_context(); \
     if (ctxobj == NULL) {       \
         return NULL;            \
-    }
-
-/* ctx := pointer to the mpd_context_t struct of the current context */
-#define CURRENT_CONTEXT_ADDR(ctx) { \
-    PyObject *_c_t_x_o_b_j = current_context(); \
-    if (_c_t_x_o_b_j == NULL) {                 \
-        return NULL;                            \
-    }                                           \
-    ctx = CTX(_c_t_x_o_b_j);                    \
-}
+    }                           \
+    Py_DECREF(ctxobj);
 
 /* Return a new reference to the current context */
 static PyObject *
 PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED)
 {
-    PyObject *context;
-
-    context = current_context();
-    if (context == NULL) {
-        return NULL;
-    }
-
-    Py_INCREF(context);
-    return context;
+    return current_context();
 }
 
 /* Set the thread local context to a new context, decrement old reference */
 static PyObject *
 PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
 {
-    PyObject *dict;
-
     CONTEXT_CHECK(v);
 
-    dict = PyThreadState_GetDict();
-    if (dict == NULL) {
-        PyErr_SetString(PyExc_RuntimeError,
-            "cannot get thread state");
-        return NULL;
-    }
-
     /* If the new context is one of the templates, make a copy.
      * This is the current behavior of decimal.py. */
     if (v == default_context_template ||
@@ -1624,13 +1561,13 @@ PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
         Py_INCREF(v);
     }
 
-    cached_context = NULL;
-    if (PyDict_SetItem(dict, tls_context_key, v) < 0) {
-        Py_DECREF(v);
+    PyContextToken *tok = PyContextVar_Set(current_context_var, v);
+    Py_DECREF(v);
+    if (tok == NULL) {
         return NULL;
     }
+    Py_DECREF(tok);
 
-    Py_DECREF(v);
     Py_RETURN_NONE;
 }
 
@@ -4458,6 +4395,7 @@ _dec_hash(PyDecObject *v)
     if (context == NULL) {
         return -1;
     }
+    Py_DECREF(context);
 
     if (mpd_isspecial(MPD(v))) {
         if (mpd_issnan(MPD(v))) {
@@ -5599,6 +5537,11 @@ PyInit__decimal(void)
     mpd_free = PyMem_Free;
     mpd_setminalloc(_Py_DEC_MINALLOC);
 
+    /* Init context variable */
+    current_context_var = PyContextVar_New("decimal_context", NULL);
+    if (current_context_var == NULL) {
+        goto error;
+    }
 
     /* Init external C-API functions */
     _py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
@@ -5768,7 +5711,6 @@ PyInit__decimal(void)
     CHECK_INT(PyModule_AddObject(m, "DefaultContext",
                                  default_context_template));
 
-    ASSIGN_PTR(tls_context_key, PyUnicode_FromString("___DECIMAL_CTX__"));
     Py_INCREF(Py_True);
     CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True));
 
@@ -5827,9 +5769,9 @@ PyInit__decimal(void)
     Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */
     Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */
     Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */
-    Py_CLEAR(tls_context_key); /* GCOV_NOT_REACHED */
     Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */
     Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */
+    Py_CLEAR(current_context_var); /* GCOV_NOT_REACHED */
     Py_CLEAR(m); /* GCOV_NOT_REACHED */
 
     return NULL; /* GCOV_NOT_REACHED */



More information about the Python-checkins mailing list