gh-112075: Add critical sections for most dict APIs (#114508)

Starts adding thread safety to dict objects.


Use @critical_section for APIs which are exposed via argument clinic and don't directly correlate with a public C API which needs to acquire the lock
Use a _lock_held suffix for keeping changes to complicated functions simple and just wrapping them with a critical section
Acquire and release the lock in an existing function where it won't be overly disruptive to the existing logic
This commit is contained in:
Dino Viehland 2024-02-06 14:03:43 -08:00 committed by GitHub
parent b6228b521b
commit 92abb01240
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 782 additions and 284 deletions

View File

@ -104,12 +104,37 @@ extern "C" {
# define Py_END_CRITICAL_SECTION2() \ # define Py_END_CRITICAL_SECTION2() \
_PyCriticalSection2_End(&_cs2); \ _PyCriticalSection2_End(&_cs2); \
} }
// Asserts that the mutex is locked. The mutex must be held by the
// top-most critical section otherwise there's the possibility
// that the mutex would be swalled out in some code paths.
#define _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(mutex) \
_PyCriticalSection_AssertHeld(mutex)
// Asserts that the mutex for the given object is locked. The mutex must
// be held by the top-most critical section otherwise there's the
// possibility that the mutex would be swalled out in some code paths.
#ifdef Py_DEBUG
#define _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(op) \
if (Py_REFCNT(op) != 1) { \
_Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&_PyObject_CAST(op)->ob_mutex); \
}
#else /* Py_DEBUG */
#define _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(op)
#endif /* Py_DEBUG */
#else /* !Py_GIL_DISABLED */ #else /* !Py_GIL_DISABLED */
// The critical section APIs are no-ops with the GIL. // The critical section APIs are no-ops with the GIL.
# define Py_BEGIN_CRITICAL_SECTION(op) # define Py_BEGIN_CRITICAL_SECTION(op)
# define Py_END_CRITICAL_SECTION() # define Py_END_CRITICAL_SECTION()
# define Py_BEGIN_CRITICAL_SECTION2(a, b) # define Py_BEGIN_CRITICAL_SECTION2(a, b)
# define Py_END_CRITICAL_SECTION2() # define Py_END_CRITICAL_SECTION2()
# define _Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(mutex)
# define _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(op)
#endif /* !Py_GIL_DISABLED */ #endif /* !Py_GIL_DISABLED */
typedef struct { typedef struct {
@ -236,6 +261,27 @@ _PyCriticalSection2_End(_PyCriticalSection2 *c)
PyAPI_FUNC(void) PyAPI_FUNC(void)
_PyCriticalSection_SuspendAll(PyThreadState *tstate); _PyCriticalSection_SuspendAll(PyThreadState *tstate);
#ifdef Py_GIL_DISABLED
static inline void
_PyCriticalSection_AssertHeld(PyMutex *mutex) {
#ifdef Py_DEBUG
PyThreadState *tstate = _PyThreadState_GET();
uintptr_t prev = tstate->critical_section;
if (prev & _Py_CRITICAL_SECTION_TWO_MUTEXES) {
_PyCriticalSection2 *cs = (_PyCriticalSection2 *)(prev & ~_Py_CRITICAL_SECTION_MASK);
assert(cs != NULL && (cs->base.mutex == mutex || cs->mutex2 == mutex));
}
else {
_PyCriticalSection *cs = (_PyCriticalSection *)(tstate->critical_section & ~_Py_CRITICAL_SECTION_MASK);
assert(cs != NULL && cs->mutex == mutex);
}
#endif
}
#endif
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -39,13 +39,14 @@ static const char copyright[] =
" SRE 2.2.2 Copyright (c) 1997-2002 by Secret Labs AB "; " SRE 2.2.2 Copyright (c) 1997-2002 by Secret Labs AB ";
#include "Python.h" #include "Python.h"
#include "pycore_dict.h" // _PyDict_Next() #include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION
#include "pycore_long.h" // _PyLong_GetZero() #include "pycore_dict.h" // _PyDict_Next()
#include "pycore_moduleobject.h" // _PyModule_GetState() #include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "sre.h" // SRE_CODE #include "sre.h" // SRE_CODE
#include <ctype.h> // tolower(), toupper(), isalnum() #include <ctype.h> // tolower(), toupper(), isalnum()
#define SRE_CODE_BITS (8 * sizeof(SRE_CODE)) #define SRE_CODE_BITS (8 * sizeof(SRE_CODE))
@ -2349,26 +2350,28 @@ _sre_SRE_Match_groupdict_impl(MatchObject *self, PyObject *default_value)
if (!result || !self->pattern->groupindex) if (!result || !self->pattern->groupindex)
return result; return result;
Py_BEGIN_CRITICAL_SECTION(self->pattern->groupindex);
while (_PyDict_Next(self->pattern->groupindex, &pos, &key, &value, &hash)) { while (_PyDict_Next(self->pattern->groupindex, &pos, &key, &value, &hash)) {
int status; int status;
Py_INCREF(key); Py_INCREF(key);
value = match_getslice(self, key, default_value); value = match_getslice(self, key, default_value);
if (!value) { if (!value) {
Py_DECREF(key); Py_DECREF(key);
goto failed; Py_CLEAR(result);
goto exit;
} }
status = _PyDict_SetItem_KnownHash(result, key, value, hash); status = _PyDict_SetItem_KnownHash(result, key, value, hash);
Py_DECREF(value); Py_DECREF(value);
Py_DECREF(key); Py_DECREF(key);
if (status < 0) if (status < 0) {
goto failed; Py_CLEAR(result);
goto exit;
}
} }
exit:
Py_END_CRITICAL_SECTION();
return result; return result;
failed:
Py_DECREF(result);
return NULL;
} }
/*[clinic input] /*[clinic input]

View File

@ -2,6 +2,7 @@
preserve preserve
[clinic start generated code]*/ [clinic start generated code]*/
#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
#include "pycore_modsupport.h" // _PyArg_CheckPositional() #include "pycore_modsupport.h" // _PyArg_CheckPositional()
PyDoc_STRVAR(dict_fromkeys__doc__, PyDoc_STRVAR(dict_fromkeys__doc__,
@ -65,6 +66,21 @@ PyDoc_STRVAR(dict___contains____doc__,
#define DICT___CONTAINS___METHODDEF \ #define DICT___CONTAINS___METHODDEF \
{"__contains__", (PyCFunction)dict___contains__, METH_O|METH_COEXIST, dict___contains____doc__}, {"__contains__", (PyCFunction)dict___contains__, METH_O|METH_COEXIST, dict___contains____doc__},
static PyObject *
dict___contains___impl(PyDictObject *self, PyObject *key);
static PyObject *
dict___contains__(PyDictObject *self, PyObject *key)
{
PyObject *return_value = NULL;
Py_BEGIN_CRITICAL_SECTION(self);
return_value = dict___contains___impl(self, key);
Py_END_CRITICAL_SECTION();
return return_value;
}
PyDoc_STRVAR(dict_get__doc__, PyDoc_STRVAR(dict_get__doc__,
"get($self, key, default=None, /)\n" "get($self, key, default=None, /)\n"
"--\n" "--\n"
@ -93,7 +109,9 @@ dict_get(PyDictObject *self, PyObject *const *args, Py_ssize_t nargs)
} }
default_value = args[1]; default_value = args[1];
skip_optional: skip_optional:
Py_BEGIN_CRITICAL_SECTION(self);
return_value = dict_get_impl(self, key, default_value); return_value = dict_get_impl(self, key, default_value);
Py_END_CRITICAL_SECTION();
exit: exit:
return return_value; return return_value;
@ -130,7 +148,9 @@ dict_setdefault(PyDictObject *self, PyObject *const *args, Py_ssize_t nargs)
} }
default_value = args[1]; default_value = args[1];
skip_optional: skip_optional:
Py_BEGIN_CRITICAL_SECTION(self);
return_value = dict_setdefault_impl(self, key, default_value); return_value = dict_setdefault_impl(self, key, default_value);
Py_END_CRITICAL_SECTION();
exit: exit:
return return_value; return return_value;
@ -209,7 +229,13 @@ dict_popitem_impl(PyDictObject *self);
static PyObject * static PyObject *
dict_popitem(PyDictObject *self, PyObject *Py_UNUSED(ignored)) dict_popitem(PyDictObject *self, PyObject *Py_UNUSED(ignored))
{ {
return dict_popitem_impl(self); PyObject *return_value = NULL;
Py_BEGIN_CRITICAL_SECTION(self);
return_value = dict_popitem_impl(self);
Py_END_CRITICAL_SECTION();
return return_value;
} }
PyDoc_STRVAR(dict___sizeof____doc__, PyDoc_STRVAR(dict___sizeof____doc__,
@ -301,4 +327,4 @@ dict_values(PyDictObject *self, PyObject *Py_UNUSED(ignored))
{ {
return dict_values_impl(self); return dict_values_impl(self);
} }
/*[clinic end generated code: output=f3ac47dfbf341b23 input=a9049054013a1b77]*/ /*[clinic end generated code: output=c8fda06bac5b05f3 input=a9049054013a1b77]*/

File diff suppressed because it is too large Load Diff

View File

@ -465,12 +465,13 @@ later:
*/ */
#include "Python.h" #include "Python.h"
#include "pycore_call.h" // _PyObject_CallNoArgs() #include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin() #include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_dict.h" // _Py_dict_lookup() #include "pycore_critical_section.h" //_Py_BEGIN_CRITICAL_SECTION
#include "pycore_object.h" // _PyObject_GC_UNTRACK() #include "pycore_dict.h" // _Py_dict_lookup()
#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1() #include "pycore_object.h" // _PyObject_GC_UNTRACK()
#include <stddef.h> // offsetof() #include "pycore_pyerrors.h" // _PyErr_ChainExceptions1()
#include <stddef.h> // offsetof()
#include "clinic/odictobject.c.h" #include "clinic/odictobject.c.h"
@ -1039,6 +1040,8 @@ _odict_popkey_hash(PyObject *od, PyObject *key, PyObject *failobj,
{ {
PyObject *value = NULL; PyObject *value = NULL;
Py_BEGIN_CRITICAL_SECTION(od);
_ODictNode *node = _odict_find_node_hash((PyODictObject *)od, key, hash); _ODictNode *node = _odict_find_node_hash((PyODictObject *)od, key, hash);
if (node != NULL) { if (node != NULL) {
/* Pop the node first to avoid a possible dict resize (due to /* Pop the node first to avoid a possible dict resize (due to
@ -1046,7 +1049,7 @@ _odict_popkey_hash(PyObject *od, PyObject *key, PyObject *failobj,
resolution. */ resolution. */
int res = _odict_clear_node((PyODictObject *)od, node, key, hash); int res = _odict_clear_node((PyODictObject *)od, node, key, hash);
if (res < 0) { if (res < 0) {
return NULL; goto done;
} }
/* Now delete the value from the dict. */ /* Now delete the value from the dict. */
if (_PyDict_Pop_KnownHash((PyDictObject *)od, key, hash, if (_PyDict_Pop_KnownHash((PyDictObject *)od, key, hash,
@ -1063,6 +1066,8 @@ _odict_popkey_hash(PyObject *od, PyObject *key, PyObject *failobj,
PyErr_SetObject(PyExc_KeyError, key); PyErr_SetObject(PyExc_KeyError, key);
} }
} }
Py_END_CRITICAL_SECTION();
done:
return value; return value;
} }

View File

@ -32,13 +32,14 @@
*/ */
#include "Python.h" #include "Python.h"
#include "pycore_ceval.h" // _PyEval_GetBuiltin() #include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_dict.h" // _PyDict_Contains_KnownHash() #include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION, Py_END_CRITICAL_SECTION
#include "pycore_modsupport.h" // _PyArg_NoKwnames() #include "pycore_dict.h" // _PyDict_Contains_KnownHash()
#include "pycore_object.h" // _PyObject_GC_UNTRACK() #include "pycore_modsupport.h" // _PyArg_NoKwnames()
#include "pycore_pyerrors.h" // _PyErr_SetKeyError() #include "pycore_object.h" // _PyObject_GC_UNTRACK()
#include "pycore_setobject.h" // _PySet_NextEntry() definition #include "pycore_pyerrors.h" // _PyErr_SetKeyError()
#include <stddef.h> // offsetof() #include "pycore_setobject.h" // _PySet_NextEntry() definition
#include <stddef.h> // offsetof()
/* Object used as dummy key to fill deleted entries */ /* Object used as dummy key to fill deleted entries */
static PyObject _dummy_struct; static PyObject _dummy_struct;
@ -903,11 +904,17 @@ set_update_internal(PySetObject *so, PyObject *other)
if (set_table_resize(so, (so->used + dictsize)*2) != 0) if (set_table_resize(so, (so->used + dictsize)*2) != 0)
return -1; return -1;
} }
int err = 0;
Py_BEGIN_CRITICAL_SECTION(other);
while (_PyDict_Next(other, &pos, &key, &value, &hash)) { while (_PyDict_Next(other, &pos, &key, &value, &hash)) {
if (set_add_entry(so, key, hash)) if (set_add_entry(so, key, hash)) {
return -1; err = -1;
goto exit;
}
} }
return 0; exit:
Py_END_CRITICAL_SECTION();
return err;
} }
it = PyObject_GetIter(other); it = PyObject_GetIter(other);
@ -1620,6 +1627,33 @@ set_isub(PySetObject *so, PyObject *other)
return Py_NewRef(so); return Py_NewRef(so);
} }
static PyObject *
set_symmetric_difference_update_dict(PySetObject *so, PyObject *other)
{
PyObject *key;
Py_ssize_t pos = 0;
Py_hash_t hash;
PyObject *value;
int rv;
while (_PyDict_Next(other, &pos, &key, &value, &hash)) {
Py_INCREF(key);
rv = set_discard_entry(so, key, hash);
if (rv < 0) {
Py_DECREF(key);
return NULL;
}
if (rv == DISCARD_NOTFOUND) {
if (set_add_entry(so, key, hash)) {
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
Py_RETURN_NONE;
}
static PyObject * static PyObject *
set_symmetric_difference_update(PySetObject *so, PyObject *other) set_symmetric_difference_update(PySetObject *so, PyObject *other)
{ {
@ -1634,23 +1668,13 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
return set_clear(so, NULL); return set_clear(so, NULL);
if (PyDict_CheckExact(other)) { if (PyDict_CheckExact(other)) {
PyObject *value; PyObject *res;
while (_PyDict_Next(other, &pos, &key, &value, &hash)) {
Py_INCREF(key); Py_BEGIN_CRITICAL_SECTION(other);
rv = set_discard_entry(so, key, hash); res = set_symmetric_difference_update_dict(so, other);
if (rv < 0) { Py_END_CRITICAL_SECTION();
Py_DECREF(key);
return NULL; return res;
}
if (rv == DISCARD_NOTFOUND) {
if (set_add_entry(so, key, hash)) {
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
Py_RETURN_NONE;
} }
if (PyAnySet_Check(other)) { if (PyAnySet_Check(other)) {