summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Include/unicodeobject.h19
-rw-r--r--Objects/typeobject.c12
-rw-r--r--Objects/unicodeobject.c38
-rw-r--r--Python/errors.c2
-rw-r--r--Python/pythonrun.c2
5 files changed, 65 insertions, 8 deletions
diff --git a/Include/unicodeobject.h b/Include/unicodeobject.h
index 007c15bd47..6b6acd7353 100644
--- a/Include/unicodeobject.h
+++ b/Include/unicodeobject.h
@@ -2000,12 +2000,31 @@ PyAPI_FUNC(int) PyUnicode_Compare(
);
#ifndef Py_LIMITED_API
+/* Compare a string with an identifier and return -1, 0, 1 for less than,
+ equal, and greater than, respectively.
+ Raise an exception and return -1 on error. */
+
PyAPI_FUNC(int) _PyUnicode_CompareWithId(
PyObject *left, /* Left string */
_Py_Identifier *right /* Right identifier */
);
+
+/* Test whether a unicode is equal to ASCII identifier. Return 1 if true,
+ 0 otherwise. Return 0 if any argument contains non-ASCII characters.
+ Any error occurs inside will be cleared before return. */
+
+PyAPI_FUNC(int) _PyUnicode_EqualToASCIIId(
+ PyObject *left, /* Left string */
+ _Py_Identifier *right /* Right identifier */
+ );
#endif
+/* Compare a Unicode object with C string and return -1, 0, 1 for less than,
+ equal, and greater than, respectively. It is best to pass only
+ ASCII-encoded strings, but the function interprets the input string as
+ ISO-8859-1 if it contains non-ASCII characters.
+ Raise an exception and return -1 on error. */
+
PyAPI_FUNC(int) PyUnicode_CompareWithASCIIString(
PyObject *left,
const char *right /* ASCII-encoded string */
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 28a2db1945..7b76e5cd4d 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -858,7 +858,7 @@ type_repr(PyTypeObject *type)
return NULL;
}
- if (mod != NULL && _PyUnicode_CompareWithId(mod, &PyId_builtins))
+ if (mod != NULL && !_PyUnicode_EqualToASCIIId(mod, &PyId_builtins))
rtn = PyUnicode_FromFormat("<class '%U.%U'>", mod, name);
else
rtn = PyUnicode_FromFormat("<class '%s'>", type->tp_name);
@@ -2386,7 +2386,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds)
if (!valid_identifier(tmp))
goto error;
assert(PyUnicode_Check(tmp));
- if (_PyUnicode_CompareWithId(tmp, &PyId___dict__) == 0) {
+ if (_PyUnicode_EqualToASCIIId(tmp, &PyId___dict__)) {
if (!may_add_dict || add_dict) {
PyErr_SetString(PyExc_TypeError,
"__dict__ slot disallowed: "
@@ -2417,7 +2417,7 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds)
for (i = j = 0; i < nslots; i++) {
tmp = PyTuple_GET_ITEM(slots, i);
if ((add_dict &&
- _PyUnicode_CompareWithId(tmp, &PyId___dict__) == 0) ||
+ _PyUnicode_EqualToASCIIId(tmp, &PyId___dict__)) ||
(add_weak &&
_PyUnicode_EqualToASCIIString(tmp, "__weakref__")))
continue;
@@ -3490,7 +3490,7 @@ object_repr(PyObject *self)
Py_XDECREF(mod);
return NULL;
}
- if (mod != NULL && _PyUnicode_CompareWithId(mod, &PyId_builtins))
+ if (mod != NULL && !_PyUnicode_EqualToASCIIId(mod, &PyId_builtins))
rtn = PyUnicode_FromFormat("<%U.%U object at %p>", mod, name, self);
else
rtn = PyUnicode_FromFormat("<%s object at %p>",
@@ -7107,7 +7107,7 @@ super_getattro(PyObject *self, PyObject *name)
(i.e. super, or a subclass), not the class of su->obj. */
if (PyUnicode_Check(name) &&
PyUnicode_GET_LENGTH(name) == 9 &&
- _PyUnicode_CompareWithId(name, &PyId___class__) == 0)
+ _PyUnicode_EqualToASCIIId(name, &PyId___class__))
goto skip;
mro = starttype->tp_mro;
@@ -7319,7 +7319,7 @@ super_init(PyObject *self, PyObject *args, PyObject *kwds)
for (i = 0; i < n; i++) {
PyObject *name = PyTuple_GET_ITEM(co->co_freevars, i);
assert(PyUnicode_Check(name));
- if (!_PyUnicode_CompareWithId(name, &PyId___class__)) {
+ if (_PyUnicode_EqualToASCIIId(name, &PyId___class__)) {
Py_ssize_t index = co->co_nlocals +
PyTuple_GET_SIZE(co->co_cellvars) + i;
PyObject *cell = f->f_localsplus[index];
diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c
index 86485bdb6a..15705e10f9 100644
--- a/Objects/unicodeobject.c
+++ b/Objects/unicodeobject.c
@@ -10869,6 +10869,44 @@ _PyUnicode_EqualToASCIIString(PyObject *unicode, const char *str)
memcmp(PyUnicode_1BYTE_DATA(unicode), str, len) == 0;
}
+int
+_PyUnicode_EqualToASCIIId(PyObject *left, _Py_Identifier *right)
+{
+ PyObject *right_uni;
+ Py_hash_t hash;
+
+ assert(_PyUnicode_CHECK(left));
+ assert(right->string);
+
+ if (PyUnicode_READY(left) == -1) {
+ /* memory error or bad data */
+ PyErr_Clear();
+ return non_ready_unicode_equal_to_ascii_string(left, right->string);
+ }
+
+ if (!PyUnicode_IS_ASCII(left))
+ return 0;
+
+ right_uni = _PyUnicode_FromId(right); /* borrowed */
+ if (right_uni == NULL) {
+ /* memory error or bad data */
+ PyErr_Clear();
+ return _PyUnicode_EqualToASCIIString(left, right->string);
+ }
+
+ if (left == right_uni)
+ return 1;
+
+ if (PyUnicode_CHECK_INTERNED(left))
+ return 0;
+
+ assert(_PyUnicode_HASH(right_uni) != 1);
+ hash = _PyUnicode_HASH(left);
+ if (hash != -1 && hash != _PyUnicode_HASH(right_uni))
+ return 0;
+
+ return unicode_compare_eq(left, right_uni);
+}
#define TEST_COND(cond) \
((cond) ? Py_True : Py_False)
diff --git a/Python/errors.c b/Python/errors.c
index 6cc0c20cd5..dd01448518 100644
--- a/Python/errors.c
+++ b/Python/errors.c
@@ -934,7 +934,7 @@ PyErr_WriteUnraisable(PyObject *obj)
goto done;
}
else {
- if (_PyUnicode_CompareWithId(moduleName, &PyId_builtins) != 0) {
+ if (!_PyUnicode_EqualToASCIIId(moduleName, &PyId_builtins)) {
if (PyFile_WriteObject(moduleName, f, Py_PRINT_RAW) < 0)
goto done;
if (PyFile_WriteString(".", f) < 0)
diff --git a/Python/pythonrun.c b/Python/pythonrun.c
index 7fbf06e68a..72b6c9b060 100644
--- a/Python/pythonrun.c
+++ b/Python/pythonrun.c
@@ -747,7 +747,7 @@ print_exception(PyObject *f, PyObject *value)
err = PyFile_WriteString("<unknown>", f);
}
else {
- if (_PyUnicode_CompareWithId(moduleName, &PyId_builtins) != 0)
+ if (!_PyUnicode_EqualToASCIIId(moduleName, &PyId_builtins))
{
err = PyFile_WriteObject(moduleName, f, Py_PRINT_RAW);
err += PyFile_WriteString(".", f);