From f5894dd646f5e39918377b37b8c8694cebdca103 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Wed, 16 Nov 2016 15:40:39 +0200 Subject: [PATCH] Issue #28701: Replace _PyUnicode_CompareWithId with _PyUnicode_EqualToASCIIId. The latter function is more readable, faster and doesn't raise exceptions. Based on patch by Xiang Zhang. --- Include/unicodeobject.h | 19 +++++++++++++++++++ Objects/typeobject.c | 12 ++++++------ Objects/unicodeobject.c | 38 ++++++++++++++++++++++++++++++++++++++ Python/errors.c | 2 +- Python/pythonrun.c | 2 +- 5 files changed, 65 insertions(+), 8 deletions(-) diff --git a/Include/unicodeobject.h b/Include/unicodeobject.h index 007c15bd47e..6b6acd73531 100644 --- a/Include/unicodeobject.h +++ b/Include/unicodeobject.h @@ -2000,12 +2000,31 @@ PyAPI_FUNC(int) PyUnicode_Compare( ); #ifndef Py_LIMITED_API +/* Compare a string with an identifier and return -1, 0, 1 for less than, + equal, and greater than, respectively. + Raise an exception and return -1 on error. */ + PyAPI_FUNC(int) _PyUnicode_CompareWithId( PyObject *left, /* Left string */ _Py_Identifier *right /* Right identifier */ ); + +/* Test whether a unicode is equal to ASCII identifier. Return 1 if true, + 0 otherwise. Return 0 if any argument contains non-ASCII characters. + Any error occurs inside will be cleared before return. */ + +PyAPI_FUNC(int) _PyUnicode_EqualToASCIIId( + PyObject *left, /* Left string */ + _Py_Identifier *right /* Right identifier */ + ); #endif +/* Compare a Unicode object with C string and return -1, 0, 1 for less than, + equal, and greater than, respectively. It is best to pass only + ASCII-encoded strings, but the function interprets the input string as + ISO-8859-1 if it contains non-ASCII characters. + Raise an exception and return -1 on error. */ + PyAPI_FUNC(int) PyUnicode_CompareWithASCIIString( PyObject *left, const char *right /* ASCII-encoded string */ diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 28a2db19456..7b76e5cd4d4 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -858,7 +858,7 @@ type_repr(PyTypeObject *type) return NULL; } - if (mod != NULL && _PyUnicode_CompareWithId(mod, &PyId_builtins)) + if (mod != NULL && !_PyUnicode_EqualToASCIIId(mod, &PyId_builtins)) rtn = PyUnicode_FromFormat("", mod, name); else rtn = PyUnicode_FromFormat("", type->tp_name); @@ -2386,7 +2386,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) if (!valid_identifier(tmp)) goto error; assert(PyUnicode_Check(tmp)); - if (_PyUnicode_CompareWithId(tmp, &PyId___dict__) == 0) { + if (_PyUnicode_EqualToASCIIId(tmp, &PyId___dict__)) { if (!may_add_dict || add_dict) { PyErr_SetString(PyExc_TypeError, "__dict__ slot disallowed: " @@ -2417,7 +2417,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds) for (i = j = 0; i < nslots; i++) { tmp = PyTuple_GET_ITEM(slots, i); if ((add_dict && - _PyUnicode_CompareWithId(tmp, &PyId___dict__) == 0) || + _PyUnicode_EqualToASCIIId(tmp, &PyId___dict__)) || (add_weak && _PyUnicode_EqualToASCIIString(tmp, "__weakref__"))) continue; @@ -3490,7 +3490,7 @@ object_repr(PyObject *self) Py_XDECREF(mod); return NULL; } - if (mod != NULL && _PyUnicode_CompareWithId(mod, &PyId_builtins)) + if (mod != NULL && !_PyUnicode_EqualToASCIIId(mod, &PyId_builtins)) rtn = PyUnicode_FromFormat("<%U.%U object at %p>", mod, name, self); else rtn = PyUnicode_FromFormat("<%s object at %p>", @@ -7107,7 +7107,7 @@ super_getattro(PyObject *self, PyObject *name) (i.e. super, or a subclass), not the class of su->obj. */ if (PyUnicode_Check(name) && PyUnicode_GET_LENGTH(name) == 9 && - _PyUnicode_CompareWithId(name, &PyId___class__) == 0) + _PyUnicode_EqualToASCIIId(name, &PyId___class__)) goto skip; mro = starttype->tp_mro; @@ -7319,7 +7319,7 @@ super_init(PyObject *self, PyObject *args, PyObject *kwds) for (i = 0; i < n; i++) { PyObject *name = PyTuple_GET_ITEM(co->co_freevars, i); assert(PyUnicode_Check(name)); - if (!_PyUnicode_CompareWithId(name, &PyId___class__)) { + if (_PyUnicode_EqualToASCIIId(name, &PyId___class__)) { Py_ssize_t index = co->co_nlocals + PyTuple_GET_SIZE(co->co_cellvars) + i; PyObject *cell = f->f_localsplus[index]; diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c index 86485bdb6a1..15705e10f9d 100644 --- a/Objects/unicodeobject.c +++ b/Objects/unicodeobject.c @@ -10869,6 +10869,44 @@ _PyUnicode_EqualToASCIIString(PyObject *unicode, const char *str) memcmp(PyUnicode_1BYTE_DATA(unicode), str, len) == 0; } +int +_PyUnicode_EqualToASCIIId(PyObject *left, _Py_Identifier *right) +{ + PyObject *right_uni; + Py_hash_t hash; + + assert(_PyUnicode_CHECK(left)); + assert(right->string); + + if (PyUnicode_READY(left) == -1) { + /* memory error or bad data */ + PyErr_Clear(); + return non_ready_unicode_equal_to_ascii_string(left, right->string); + } + + if (!PyUnicode_IS_ASCII(left)) + return 0; + + right_uni = _PyUnicode_FromId(right); /* borrowed */ + if (right_uni == NULL) { + /* memory error or bad data */ + PyErr_Clear(); + return _PyUnicode_EqualToASCIIString(left, right->string); + } + + if (left == right_uni) + return 1; + + if (PyUnicode_CHECK_INTERNED(left)) + return 0; + + assert(_PyUnicode_HASH(right_uni) != 1); + hash = _PyUnicode_HASH(left); + if (hash != -1 && hash != _PyUnicode_HASH(right_uni)) + return 0; + + return unicode_compare_eq(left, right_uni); +} #define TEST_COND(cond) \ ((cond) ? Py_True : Py_False) diff --git a/Python/errors.c b/Python/errors.c index 6cc0c20cd55..dd014485188 100644 --- a/Python/errors.c +++ b/Python/errors.c @@ -934,7 +934,7 @@ PyErr_WriteUnraisable(PyObject *obj) goto done; } else { - if (_PyUnicode_CompareWithId(moduleName, &PyId_builtins) != 0) { + if (!_PyUnicode_EqualToASCIIId(moduleName, &PyId_builtins)) { if (PyFile_WriteObject(moduleName, f, Py_PRINT_RAW) < 0) goto done; if (PyFile_WriteString(".", f) < 0) diff --git a/Python/pythonrun.c b/Python/pythonrun.c index 7fbf06e68a1..72b6c9b0608 100644 --- a/Python/pythonrun.c +++ b/Python/pythonrun.c @@ -747,7 +747,7 @@ print_exception(PyObject *f, PyObject *value) err = PyFile_WriteString("", f); } else { - if (_PyUnicode_CompareWithId(moduleName, &PyId_builtins) != 0) + if (!_PyUnicode_EqualToASCIIId(moduleName, &PyId_builtins)) { err = PyFile_WriteObject(moduleName, f, Py_PRINT_RAW); err += PyFile_WriteString(".", f);