libstdc++: Fix std::codecvt<wchar_t, char, mbstate_t> for empty dest [PR37475]

For the GNU locale model, codecvt::do_out and codecvt::do_in incorrectly
return 'ok' when the destination range is empty. That happens because
detecting incomplete output is done in the loop body, and the loop is
never even entered if to == to_end.

By restructuring the loop condition so that we check the output range
separately, we can ensure that for a non-empty source range, we always
enter the loop at least once, and detect if the destination range is too
small.

The loops also seem easier to reason about if we return immediately on
any error, instead of checking the result twice on every iteration. We
can use an RAII type to restore the locale before returning, which also
simplifies all the other member functions.

libstdc++-v3/ChangeLog:

	PR libstdc++/37475
	* config/locale/gnu/codecvt_members.cc (Guard): New RAII type.
	(do_out, do_in): Return partial if the destination is empty but
	the source is not. Use Guard to restore locale on scope exit.
	Return immediately on any conversion error.
	(do_encoding, do_max_length, do_length): Use Guard.
	* testsuite/22_locale/codecvt/in/char/37475.cc: New test.
	* testsuite/22_locale/codecvt/in/wchar_t/37475.cc: New test.
	* testsuite/22_locale/codecvt/out/char/37475.cc: New test.
	* testsuite/22_locale/codecvt/out/wchar_t/37475.cc: New test.
This commit is contained in:
Jonathan Wakely 2024-06-11 16:45:43 +01:00 committed by Jonathan Wakely
parent 95faa1bea7
commit 73ad57c244
No known key found for this signature in database
5 changed files with 142 additions and 67 deletions

View File

@ -37,8 +37,23 @@ namespace std _GLIBCXX_VISIBILITY(default)
{
_GLIBCXX_BEGIN_NAMESPACE_VERSION
// Specializations.
#ifdef _GLIBCXX_USE_WCHAR_T
namespace
{
// RAII type for changing and restoring the current thread's locale.
struct Guard
{
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
explicit Guard(__c_locale loc) : old(__uselocale(loc)) { }
~Guard() { __uselocale(old); }
#else
explicit Guard(__c_locale) { }
#endif
__c_locale old;
};
}
// Specializations.
codecvt_base::result
codecvt<wchar_t, char, mbstate_t>::
do_out(state_type& __state, const intern_type* __from,
@ -46,22 +61,21 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
extern_type* __to, extern_type* __to_end,
extern_type*& __to_next) const
{
result __ret = ok;
state_type __tmp_state(__state);
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__c_locale __old = __uselocale(_M_c_locale_codecvt);
#endif
Guard g(_M_c_locale_codecvt);
// wcsnrtombs is *very* fast but stops if encounters NUL characters:
// in case we fall back to wcrtomb and then continue, in a loop.
// NB: wcsnrtombs is a GNU extension
for (__from_next = __from, __to_next = __to;
__from_next < __from_end && __to_next < __to_end
&& __ret == ok;)
__from_next = __from;
__to_next = __to;
while (__from_next < __from_end)
{
const intern_type* __from_chunk_end = wmemchr(__from_next, L'\0',
__from_end - __from_next);
if (__to_next >= __to_end)
return partial;
const intern_type* __from_chunk_end
= wmemchr(__from_next, L'\0', __from_end - __from_next);
if (!__from_chunk_end)
__from_chunk_end = __from_end;
@ -77,12 +91,12 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
for (; __from < __from_next; ++__from)
__to_next += wcrtomb(__to_next, *__from, &__tmp_state);
__state = __tmp_state;
__ret = error;
return error;
}
else if (__from_next && __from_next < __from_chunk_end)
{
__to_next += __conv;
__ret = partial;
return partial;
}
else
{
@ -90,13 +104,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
__to_next += __conv;
}
if (__from_next < __from_end && __ret == ok)
if (__from_next < __from_end)
{
extern_type __buf[MB_LEN_MAX];
__tmp_state = __state;
const size_t __conv2 = wcrtomb(__buf, *__from_next, &__tmp_state);
if (__conv2 > static_cast<size_t>(__to_end - __to_next))
__ret = partial;
return partial;
else
{
memcpy(__to_next, __buf, __conv2);
@ -107,11 +121,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
}
}
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__uselocale(__old);
#endif
return __ret;
return ok;
}
codecvt_base::result
@ -121,24 +131,22 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
intern_type* __to, intern_type* __to_end,
intern_type*& __to_next) const
{
result __ret = ok;
state_type __tmp_state(__state);
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__c_locale __old = __uselocale(_M_c_locale_codecvt);
#endif
Guard g(_M_c_locale_codecvt);
// mbsnrtowcs is *very* fast but stops if encounters NUL characters:
// in case we store a L'\0' and then continue, in a loop.
// NB: mbsnrtowcs is a GNU extension
for (__from_next = __from, __to_next = __to;
__from_next < __from_end && __to_next < __to_end
&& __ret == ok;)
__from_next = __from;
__to_next = __to;
while (__from_next < __from_end)
{
const extern_type* __from_chunk_end;
__from_chunk_end = static_cast<const extern_type*>(memchr(__from_next, '\0',
__from_end
- __from_next));
if (__to_next >= __to_end)
return partial;
const extern_type* __from_chunk_end
= static_cast<const extern_type*>(memchr(__from_next, '\0',
__from_end - __from_next));
if (!__from_chunk_end)
__from_chunk_end = __from_end;
@ -161,13 +169,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
}
__from_next = __from;
__state = __tmp_state;
__ret = error;
return error;
}
else if (__from_next && __from_next < __from_chunk_end)
{
// It is unclear what to return in this case (see DR 382).
__to_next += __conv;
__ret = partial;
return partial;
}
else
{
@ -175,7 +183,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
__to_next += __conv;
}
if (__from_next < __from_end && __ret == ok)
if (__from_next < __from_end)
{
if (__to_next < __to_end)
{
@ -185,48 +193,30 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
*__to_next++ = L'\0';
}
else
__ret = partial;
return partial;
}
}
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__uselocale(__old);
#endif
return __ret;
return ok;
}
int
codecvt<wchar_t, char, mbstate_t>::
do_encoding() const throw()
{
Guard g(_M_c_locale_codecvt);
// XXX This implementation assumes that the encoding is
// stateless and is either single-byte or variable-width.
int __ret = 0;
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__c_locale __old = __uselocale(_M_c_locale_codecvt);
#endif
if (MB_CUR_MAX == 1)
__ret = 1;
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__uselocale(__old);
#endif
return __ret;
return MB_CUR_MAX == 1;
}
int
codecvt<wchar_t, char, mbstate_t>::
do_max_length() const throw()
{
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__c_locale __old = __uselocale(_M_c_locale_codecvt);
#endif
Guard g(_M_c_locale_codecvt);
// XXX Probably wrong for stateful encodings.
int __ret = MB_CUR_MAX;
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__uselocale(__old);
#endif
return __ret;
return MB_CUR_MAX;
}
int
@ -236,10 +226,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
{
int __ret = 0;
state_type __tmp_state(__state);
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__c_locale __old = __uselocale(_M_c_locale_codecvt);
#endif
Guard g(_M_c_locale_codecvt);
// mbsnrtowcs is *very* fast but stops if encounters NUL characters:
// in case we advance past it and then continue, in a loop.
@ -295,10 +282,6 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
}
}
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
__uselocale(__old);
#endif
return __ret;
}
#endif

View File

@ -0,0 +1,23 @@
#include <locale>
#include <testsuite_hooks.h>
void
test_pr37475()
{
typedef std::codecvt<char, char, std::mbstate_t> test_type;
const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
const char from = 'a';
const char* from_next;
char to = 0;
char* to_next;
std::mbstate_t st = std::mbstate_t();
std::codecvt_base::result res
= cvt.in(st, &from, &from+1, from_next, &to, &to, to_next);
VERIFY( res == std::codecvt_base::noconv );
}
int main()
{
test_pr37475();
}

View File

@ -0,0 +1,23 @@
#include <locale>
#include <testsuite_hooks.h>
void
test_pr37475()
{
typedef std::codecvt<wchar_t, char, std::mbstate_t> test_type;
const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
const char from = 'a';
const char* from_next;
wchar_t to = 0;
wchar_t* to_next;
std::mbstate_t st = std::mbstate_t();
std::codecvt_base::result res
= cvt.in(st, &from, &from+1, from_next, &to, &to, to_next);
VERIFY( res == std::codecvt_base::partial );
}
int main()
{
test_pr37475();
}

View File

@ -0,0 +1,23 @@
#include <locale>
#include <assert.h>
void
test_pr37475()
{
typedef std::codecvt<char, char, std::mbstate_t> test_type;
const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
const char from = 'a';
const char* from_next;
char to;
char* to_next;
std::mbstate_t st = std::mbstate_t();
std::codecvt_base::result res
= cvt.out(st, &from, &from+1, from_next, &to, &to, to_next);
assert( res == std::codecvt_base::noconv );
}
int main()
{
test_pr37475();
}

View File

@ -0,0 +1,23 @@
#include <locale>
#include <assert.h>
void
test_pr37475()
{
typedef std::codecvt<wchar_t, char, std::mbstate_t> test_type;
const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
const wchar_t from = L'a';
const wchar_t* from_next;
char to;
char* to_next;
std::mbstate_t st = std::mbstate_t();
std::codecvt_base::result res
= cvt.out(st, &from, &from+1, from_next, &to, &to, to_next);
assert( res == std::codecvt_base::partial );
}
int main()
{
test_pr37475();
}