gh-115999: Add free-threaded specialization for `TO_BOOL` (gh-126616)

This commit is contained in:
Donghee Na 2024-11-22 07:52:16 +09:00 committed by GitHub
parent 09c240f20c
commit 78a530a578
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 168 additions and 69 deletions

View File

@ -269,6 +269,16 @@ extern unsigned int _PyType_GetVersionForCurrentState(PyTypeObject *tp);
PyAPI_FUNC(void) _PyType_SetVersion(PyTypeObject *tp, unsigned int version); PyAPI_FUNC(void) _PyType_SetVersion(PyTypeObject *tp, unsigned int version);
PyTypeObject *_PyType_LookupByVersion(unsigned int version); PyTypeObject *_PyType_LookupByVersion(unsigned int version);
// Function pointer type for a user-defined validation function that is
// called by _PyType_Validate().
// It should return 0 if validation passes and a non-zero value otherwise.
// _PyType_Validate() hands a non-zero result back to its caller unchanged,
// so the callback may use distinct non-zero codes to report why it failed.
typedef int (*_py_validate_type)(PyTypeObject *);
// Validate ``ty`` with the user-defined function ``validate``.
// Returns 0 if validation passes and a version tag could be assigned to
// ``ty``; in that case the type's tp_version_tag is stored in ``*tp_version``.
// Returns -1 if validation passes but no version tag could be assigned.
// Otherwise returns the non-zero value produced by ``validate``.
// ``*tp_version`` is only written when 0 is returned.
extern int _PyType_Validate(PyTypeObject *ty, _py_validate_type validate, unsigned int *tp_version);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -1272,6 +1272,72 @@ class TestSpecializer(TestBase):
self.assert_specialized(g, "CONTAINS_OP_SET") self.assert_specialized(g, "CONTAINS_OP_SET")
self.assert_no_opcode(g, "CONTAINS_OP") self.assert_no_opcode(g, "CONTAINS_OP")
# Exercise TO_BOOL on bool, int, list, None and str operands.
# Each nested helper performs repeated truth tests on one operand type and
# checks its own arithmetic result; afterwards the helper is passed to
# assert_specialized() with the expected opcode name and to
# assert_no_opcode() with "TO_BOOL".
@cpython_only
@requires_specialization_ft
def test_to_bool(self):
# bool operands: half of range(100) is even, so 50 True and 50 False.
def to_bool_bool():
true_cnt, false_cnt = 0, 0
elems = [e % 2 == 0 for e in range(100)]
for e in elems:
if e:
true_cnt += 1
else:
false_cnt += 1
self.assertEqual(true_cnt, 50)
self.assertEqual(false_cnt, 50)
to_bool_bool()
self.assert_specialized(to_bool_bool, "TO_BOOL_BOOL")
self.assert_no_opcode(to_bool_bool, "TO_BOOL")
# int operands: range(100) holds one falsy value (0) and 99 truthy ones,
# so the counter ends at 99 - 1 == 98.
def to_bool_int():
count = 0
for i in range(100):
if i:
count += 1
else:
count -= 1
self.assertEqual(count, 98)
to_bool_int()
self.assert_specialized(to_bool_int, "TO_BOOL_INT")
self.assert_no_opcode(to_bool_int, "TO_BOOL")
# list operand: the loop condition truth-tests the list itself and stops
# once every element has been popped (3 + 2 + 1 == 6).
def to_bool_list():
count = 0
elems = [1, 2, 3]
while elems:
count += elems.pop()
self.assertEqual(elems, [])
self.assertEqual(count, 6)
to_bool_list()
self.assert_specialized(to_bool_list, "TO_BOOL_LIST")
self.assert_no_opcode(to_bool_list, "TO_BOOL")
# None operands: ``not e`` holds for every element, so all are counted.
def to_bool_none():
count = 0
elems = [None, None, None, None]
for e in elems:
if not e:
count += 1
self.assertEqual(count, len(elems))
to_bool_none()
self.assert_specialized(to_bool_none, "TO_BOOL_NONE")
self.assert_no_opcode(to_bool_none, "TO_BOOL")
# str operands: only the non-empty string "foo" is truthy.
def to_bool_str():
count = 0
elems = ["", "foo", ""]
for e in elems:
if e:
count += 1
self.assertEqual(count, 1)
to_bool_str()
self.assert_specialized(to_bool_str, "TO_BOOL_STR")
self.assert_no_opcode(to_bool_str, "TO_BOOL")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -5645,6 +5645,24 @@ _PyType_SetFlags(PyTypeObject *self, unsigned long mask, unsigned long flags)
END_TYPE_LOCK(); END_TYPE_LOCK();
} }
/* Run ``validate`` on ``ty`` and, if it accepts the type, assign a version
 * tag and report it through ``tp_version``.
 *
 * The callback and the version-tag assignment both run between
 * BEGIN_TYPE_LOCK() and END_TYPE_LOCK().
 *
 * Returns:
 *    0         ``validate`` returned 0 and a version tag was assigned;
 *              ``*tp_version`` holds ``ty->tp_version_tag``.
 *   -1         ``validate`` returned 0 but assign_version_tag() failed.
 *   otherwise  the non-zero value returned by ``validate``, unchanged.
 *
 * ``*tp_version`` is written only in the first case.
 */
int
_PyType_Validate(PyTypeObject *ty, _py_validate_type validate, unsigned int *tp_version)
{
    int status;
    BEGIN_TYPE_LOCK();
    status = validate(ty);
    if (status == 0) {
        if (!assign_version_tag(_PyInterpreterState_GET(), ty)) {
            // Validation passed but no version tag is available.
            status = -1;
        }
        else {
            *tp_version = ty->tp_version_tag;
        }
    }
    END_TYPE_LOCK();
    return status;
}
static void static void
set_flags_recursive(PyTypeObject *self, unsigned long mask, unsigned long flags) set_flags_recursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
{ {

View File

@ -391,7 +391,7 @@ dummy_func(
}; };
specializing op(_SPECIALIZE_TO_BOOL, (counter/1, value -- value)) { specializing op(_SPECIALIZE_TO_BOOL, (counter/1, value -- value)) {
#if ENABLE_SPECIALIZATION #if ENABLE_SPECIALIZATION_FT
if (ADAPTIVE_COUNTER_TRIGGERS(counter)) { if (ADAPTIVE_COUNTER_TRIGGERS(counter)) {
next_instr = this_instr; next_instr = this_instr;
_Py_Specialize_ToBool(value, next_instr); _Py_Specialize_ToBool(value, next_instr);
@ -399,7 +399,7 @@ dummy_func(
} }
OPCODE_DEFERRED_INC(TO_BOOL); OPCODE_DEFERRED_INC(TO_BOOL);
ADVANCE_ADAPTIVE_COUNTER(this_instr[1].counter); ADVANCE_ADAPTIVE_COUNTER(this_instr[1].counter);
#endif /* ENABLE_SPECIALIZATION */ #endif /* ENABLE_SPECIALIZATION_FT */
} }
op(_TO_BOOL, (value -- res)) { op(_TO_BOOL, (value -- res)) {
@ -435,7 +435,7 @@ dummy_func(
PyObject *value_o = PyStackRef_AsPyObjectBorrow(value); PyObject *value_o = PyStackRef_AsPyObjectBorrow(value);
EXIT_IF(!PyList_CheckExact(value_o)); EXIT_IF(!PyList_CheckExact(value_o));
STAT_INC(TO_BOOL, hit); STAT_INC(TO_BOOL, hit);
res = Py_SIZE(value_o) ? PyStackRef_True : PyStackRef_False; res = PyList_GET_SIZE(value_o) ? PyStackRef_True : PyStackRef_False;
DECREF_INPUTS(); DECREF_INPUTS();
} }

View File

@ -508,7 +508,7 @@
JUMP_TO_JUMP_TARGET(); JUMP_TO_JUMP_TARGET();
} }
STAT_INC(TO_BOOL, hit); STAT_INC(TO_BOOL, hit);
res = Py_SIZE(value_o) ? PyStackRef_True : PyStackRef_False; res = PyList_GET_SIZE(value_o) ? PyStackRef_True : PyStackRef_False;
PyStackRef_CLOSE(value); PyStackRef_CLOSE(value);
stack_pointer[-1] = res; stack_pointer[-1] = res;
break; break;

View File

@ -7758,7 +7758,7 @@
value = stack_pointer[-1]; value = stack_pointer[-1];
uint16_t counter = read_u16(&this_instr[1].cache); uint16_t counter = read_u16(&this_instr[1].cache);
(void)counter; (void)counter;
#if ENABLE_SPECIALIZATION #if ENABLE_SPECIALIZATION_FT
if (ADAPTIVE_COUNTER_TRIGGERS(counter)) { if (ADAPTIVE_COUNTER_TRIGGERS(counter)) {
next_instr = this_instr; next_instr = this_instr;
_PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_SetStackPointer(frame, stack_pointer);
@ -7768,7 +7768,7 @@
} }
OPCODE_DEFERRED_INC(TO_BOOL); OPCODE_DEFERRED_INC(TO_BOOL);
ADVANCE_ADAPTIVE_COUNTER(this_instr[1].counter); ADVANCE_ADAPTIVE_COUNTER(this_instr[1].counter);
#endif /* ENABLE_SPECIALIZATION */ #endif /* ENABLE_SPECIALIZATION_FT */
} }
/* Skip 2 cache entries */ /* Skip 2 cache entries */
// _TO_BOOL // _TO_BOOL
@ -7863,7 +7863,7 @@
PyObject *value_o = PyStackRef_AsPyObjectBorrow(value); PyObject *value_o = PyStackRef_AsPyObjectBorrow(value);
DEOPT_IF(!PyList_CheckExact(value_o), TO_BOOL); DEOPT_IF(!PyList_CheckExact(value_o), TO_BOOL);
STAT_INC(TO_BOOL, hit); STAT_INC(TO_BOOL, hit);
res = Py_SIZE(value_o) ? PyStackRef_True : PyStackRef_False; res = PyList_GET_SIZE(value_o) ? PyStackRef_True : PyStackRef_False;
PyStackRef_CLOSE(value); PyStackRef_CLOSE(value);
stack_pointer[-1] = res; stack_pointer[-1] = res;
DISPATCH(); DISPATCH();

View File

@ -2667,101 +2667,106 @@ success:
cache->counter = adaptive_counter_cooldown(); cache->counter = adaptive_counter_cooldown();
} }
#ifdef Py_STATS
/* Pick the SPEC_FAIL_* code that describes ``value``.
 *
 * The exact-type checks are tried in a fixed order (bytearray, bytes, dict,
 * float, memoryview, set/frozenset, tuple) and the first match decides the
 * result.  The memoryview test uses PyMemoryView_Check(); every other test
 * uses a *_CheckExact() form.  SPEC_FAIL_OTHER is returned when nothing
 * matches.
 */
static int
to_bool_fail_kind(PyObject *value)
{
    int kind = SPEC_FAIL_OTHER;

    if (PyByteArray_CheckExact(value)) {
        kind = SPEC_FAIL_TO_BOOL_BYTEARRAY;
    }
    else if (PyBytes_CheckExact(value)) {
        kind = SPEC_FAIL_TO_BOOL_BYTES;
    }
    else if (PyDict_CheckExact(value)) {
        kind = SPEC_FAIL_TO_BOOL_DICT;
    }
    else if (PyFloat_CheckExact(value)) {
        kind = SPEC_FAIL_TO_BOOL_FLOAT;
    }
    else if (PyMemoryView_Check(value)) {
        kind = SPEC_FAIL_TO_BOOL_MEMORY_VIEW;
    }
    else if (PyAnySet_CheckExact(value)) {
        kind = SPEC_FAIL_TO_BOOL_SET;
    }
    else if (PyTuple_CheckExact(value)) {
        kind = SPEC_FAIL_TO_BOOL_TUPLE;
    }
    return kind;
}
#endif // Py_STATS
/* Validation callback handed to _PyType_Validate() when specializing TO_BOOL
 * to TO_BOOL_ALWAYS_TRUE.
 *
 * Returns 0 when ``ty`` provides none of the nb_bool, mp_length and
 * sq_length slots.  Otherwise returns the SPEC_FAIL_TO_BOOL_* code for the
 * first such slot found, checking number, then mapping, then sequence.
 */
static int
check_type_always_true(PyTypeObject *ty)
{
    PyNumberMethods *as_number = ty->tp_as_number;
    if (as_number != NULL && as_number->nb_bool != NULL) {
        return SPEC_FAIL_TO_BOOL_NUMBER;
    }

    PyMappingMethods *as_mapping = ty->tp_as_mapping;
    if (as_mapping != NULL && as_mapping->mp_length != NULL) {
        return SPEC_FAIL_TO_BOOL_MAPPING;
    }

    PySequenceMethods *as_sequence = ty->tp_as_sequence;
    if (as_sequence != NULL && as_sequence->sq_length != NULL) {
        return SPEC_FAIL_TO_BOOL_SEQUENCE;
    }

    // No slot that could make instances falsy was found.
    return 0;
}
void void
_Py_Specialize_ToBool(_PyStackRef value_o, _Py_CODEUNIT *instr) _Py_Specialize_ToBool(_PyStackRef value_o, _Py_CODEUNIT *instr)
{ {
assert(ENABLE_SPECIALIZATION); assert(ENABLE_SPECIALIZATION_FT);
assert(_PyOpcode_Caches[TO_BOOL] == INLINE_CACHE_ENTRIES_TO_BOOL); assert(_PyOpcode_Caches[TO_BOOL] == INLINE_CACHE_ENTRIES_TO_BOOL);
_PyToBoolCache *cache = (_PyToBoolCache *)(instr + 1); _PyToBoolCache *cache = (_PyToBoolCache *)(instr + 1);
PyObject *value = PyStackRef_AsPyObjectBorrow(value_o); PyObject *value = PyStackRef_AsPyObjectBorrow(value_o);
uint8_t specialized_op;
if (PyBool_Check(value)) { if (PyBool_Check(value)) {
instr->op.code = TO_BOOL_BOOL; specialized_op = TO_BOOL_BOOL;
goto success; goto success;
} }
if (PyLong_CheckExact(value)) { if (PyLong_CheckExact(value)) {
instr->op.code = TO_BOOL_INT; specialized_op = TO_BOOL_INT;
goto success; goto success;
} }
if (PyList_CheckExact(value)) { if (PyList_CheckExact(value)) {
instr->op.code = TO_BOOL_LIST; specialized_op = TO_BOOL_LIST;
goto success; goto success;
} }
if (Py_IsNone(value)) { if (Py_IsNone(value)) {
instr->op.code = TO_BOOL_NONE; specialized_op = TO_BOOL_NONE;
goto success; goto success;
} }
if (PyUnicode_CheckExact(value)) { if (PyUnicode_CheckExact(value)) {
instr->op.code = TO_BOOL_STR; specialized_op = TO_BOOL_STR;
goto success; goto success;
} }
if (PyType_HasFeature(Py_TYPE(value), Py_TPFLAGS_HEAPTYPE)) { if (PyType_HasFeature(Py_TYPE(value), Py_TPFLAGS_HEAPTYPE)) {
PyNumberMethods *nb = Py_TYPE(value)->tp_as_number; unsigned int version = 0;
if (nb && nb->nb_bool) { int err = _PyType_Validate(Py_TYPE(value), check_type_always_true, &version);
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_NUMBER); if (err < 0) {
goto failure;
}
PyMappingMethods *mp = Py_TYPE(value)->tp_as_mapping;
if (mp && mp->mp_length) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_MAPPING);
goto failure;
}
PySequenceMethods *sq = Py_TYPE(value)->tp_as_sequence;
if (sq && sq->sq_length) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_SEQUENCE);
goto failure;
}
if (!PyUnstable_Type_AssignVersionTag(Py_TYPE(value))) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_OUT_OF_VERSIONS); SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_OUT_OF_VERSIONS);
goto failure; goto failure;
} }
uint32_t version = type_get_version(Py_TYPE(value), TO_BOOL); else if (err > 0) {
if (version == 0) { SPECIALIZATION_FAIL(TO_BOOL, err);
goto failure; goto failure;
} }
instr->op.code = TO_BOOL_ALWAYS_TRUE;
write_u32(cache->version, version); assert(err == 0);
assert(version); assert(version);
write_u32(cache->version, version);
specialized_op = TO_BOOL_ALWAYS_TRUE;
goto success; goto success;
} }
#ifdef Py_STATS
if (PyByteArray_CheckExact(value)) { SPECIALIZATION_FAIL(TO_BOOL, to_bool_fail_kind(value));
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_BYTEARRAY);
goto failure;
}
if (PyBytes_CheckExact(value)) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_BYTES);
goto failure;
}
if (PyDict_CheckExact(value)) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_DICT);
goto failure;
}
if (PyFloat_CheckExact(value)) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_FLOAT);
goto failure;
}
if (PyMemoryView_Check(value)) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_MEMORY_VIEW);
goto failure;
}
if (PyAnySet_CheckExact(value)) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_SET);
goto failure;
}
if (PyTuple_CheckExact(value)) {
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_TO_BOOL_TUPLE);
goto failure;
}
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_OTHER);
#endif // Py_STATS
failure: failure:
STAT_INC(TO_BOOL, failure); unspecialize(instr);
instr->op.code = TO_BOOL;
cache->counter = adaptive_counter_backoff(cache->counter);
return; return;
success: success:
STAT_INC(TO_BOOL, success); specialize(instr, specialized_op);
cache->counter = adaptive_counter_cooldown();
} }
#ifdef Py_STATS #ifdef Py_STATS