Skip to content

Commit 9bcce40

Browse files
authored
Fix potential unicode conversion issues for *nix (microsoft#7506)
There were multiple issues with Unicode conversion on *nix platforms. This PR fixes issues I found with the conversion functions that were causing failures when running locally, due to issues with setting the locale. The code also behaved incorrectly when emulating the MultiByteToWideChar API. This change makes the locale setting thread-safe and more robust to different available locales in runtime environments. I fixed some off-by-one issues related to null termination, and eliminated some extra copies caused by detecting a string length, then passing the size without the null-terminator to a function which then had to copy the input string again to guarantee null-termination. The CompilerTest::CompileWithEncodeFlagTestSource test has minor updates for clarity and an added scenario. The changed code passes the Unicode tests now without asserting across all platforms. This change should have no functional impacts, except eliminating potential double-null-termination in some cases, and catching more error conditions.
1 parent 4cc1abe commit 9bcce40

File tree

3 files changed

+148
-82
lines changed

3 files changed

+148
-82
lines changed

include/dxc/WinAdapter.h

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -916,19 +916,35 @@ unsigned int SysStringLen(const BSTR bstrString);
916916
// RAII style mechanism for setting/unsetting a locale for the specified Windows
917917
// codepage
918918
class ScopedLocale {
919-
const char *m_prevLocale;
919+
locale_t Utf8Locale = nullptr;
920+
locale_t PrevLocale = nullptr;
920921

921922
public:
922-
explicit ScopedLocale(uint32_t codePage)
923-
: m_prevLocale(setlocale(LC_ALL, nullptr)) {
924-
assert((codePage == CP_UTF8) &&
923+
explicit ScopedLocale(uint32_t CodePage) {
924+
assert((CodePage == CP_UTF8) &&
925925
"Support for Linux only handles UTF8 code pages");
926-
setlocale(LC_ALL, "en_US.UTF-8");
926+
Utf8Locale = newlocale(LC_CTYPE_MASK, "C.UTF-8", NULL);
927+
if (!Utf8Locale)
928+
Utf8Locale = newlocale(LC_CTYPE_MASK, "C.utf8", NULL);
929+
if (!Utf8Locale)
930+
Utf8Locale = newlocale(LC_CTYPE_MASK, "en_US.UTF-8", NULL);
931+
assert(Utf8Locale && "Failed to create UTF-8 locale");
932+
if (!Utf8Locale)
933+
return;
934+
PrevLocale = uselocale(Utf8Locale);
935+
assert(PrevLocale && "Failed to set locale to UTF-8");
936+
if (!PrevLocale) {
937+
freelocale(Utf8Locale);
938+
Utf8Locale = nullptr;
939+
}
927940
}
928941
~ScopedLocale() {
929-
if (m_prevLocale != nullptr) {
930-
setlocale(LC_ALL, m_prevLocale);
931-
}
942+
if (PrevLocale != nullptr)
943+
uselocale(PrevLocale);
944+
if (Utf8Locale)
945+
freelocale(Utf8Locale);
946+
PrevLocale = nullptr;
947+
Utf8Locale = nullptr;
932948
}
933949
};
934950

lib/DxcSupport/Unicode.cpp

Lines changed: 91 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
2828
const char *lpMultiByteStr, int cbMultiByte,
2929
wchar_t *lpWideCharStr, int cchWideChar) {
3030

31-
if (cbMultiByte == 0) {
31+
// Check for invalid sizes or potential overflow.
32+
if (cbMultiByte == 0 || cbMultiByte < -1 || cbMultiByte == INT32_MAX ||
33+
cchWideChar < 0 || cchWideChar == INT32_MAX) {
3234
SetLastError(ERROR_INVALID_PARAMETER);
3335
return 0;
3436
}
@@ -42,18 +44,17 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
4244
++cbMultiByte;
4345
}
4446
// If zero is given as the destination size, this function should
45-
// return the required size (including the null-terminating character).
47+
// return the required size (including or excluding the null-terminating
48+
// character depending on whether the input included the null-terminator).
4649
// This is the behavior of mbstowcs when the target is null.
4750
if (cchWideChar == 0) {
4851
lpWideCharStr = nullptr;
49-
} else if (cchWideChar < cbMultiByte) {
50-
SetLastError(ERROR_INSUFFICIENT_BUFFER);
51-
return 0;
5252
}
5353

54+
ScopedLocale utf8_locale_scope(CP_UTF8);
55+
56+
bool isNullTerminated = false;
5457
size_t rv;
55-
const char *prevLocale = setlocale(LC_ALL, nullptr);
56-
setlocale(LC_ALL, "en_US.UTF-8");
5758
if (lpMultiByteStr[cbMultiByte - 1] != '\0') {
5859
char *srcStr = (char *)malloc((cbMultiByte + 1) * sizeof(char));
5960
strncpy(srcStr, lpMultiByteStr, cbMultiByte);
@@ -62,14 +63,29 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
6263
free(srcStr);
6364
} else {
6465
rv = mbstowcs(lpWideCharStr, lpMultiByteStr, cchWideChar);
66+
isNullTerminated = true;
67+
}
68+
69+
if (rv == ~(size_t)0) {
70+
// mbstowcs returns -1 on error.
71+
SetLastError(ERROR_INVALID_PARAMETER);
72+
return 0;
6573
}
6674

67-
if (prevLocale)
68-
setlocale(LC_ALL, prevLocale);
75+
// Return value of mbstowcs (rv) excludes the terminating character.
76+
// Matching MultiByteToWideChar requires returning the size written including
77+
// the null terminator if the input was null-terminated, otherwise it
78+
// returns the size written excluding the null terminator.
79+
if (isNullTerminated)
80+
rv += 1;
81+
82+
// Check for overflow when returning the size.
83+
if (rv >= INT32_MAX) {
84+
SetLastError(ERROR_INVALID_PARAMETER);
85+
return 0; // Overflow error
86+
}
6987

70-
if (rv == (size_t)cbMultiByte)
71-
return rv;
72-
return rv + 1; // mbstowcs excludes the terminating character
88+
return rv;
7389
}
7490

7591
// WideCharToMultiByte is a Windows-specific method.
@@ -84,7 +100,9 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
84100
*lpUsedDefaultChar = FALSE;
85101
}
86102

87-
if (cchWideChar == 0) {
103+
// Check for invalid sizes or potential overflow.
104+
if (cchWideChar == 0 || cchWideChar < -1 || cchWideChar > (INT32_MAX - 1) ||
105+
cbMultiByte < 0 || cbMultiByte > (INT32_MAX - 1)) {
88106
SetLastError(ERROR_INVALID_PARAMETER);
89107
return 0;
90108
}
@@ -98,18 +116,17 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
98116
++cchWideChar;
99117
}
100118
// If zero is given as the destination size, this function should
101-
// return the required size (including the null-terminating character).
119+
// return the required size (including or excluding the null-terminating
120+
// character depending on whether the input included the null-terminator).
102121
// This is the behavior of wcstombs when the target is null.
103122
if (cbMultiByte == 0) {
104123
lpMultiByteStr = nullptr;
105-
} else if (cbMultiByte < cchWideChar) {
106-
SetLastError(ERROR_INSUFFICIENT_BUFFER);
107-
return 0;
108124
}
109125

126+
ScopedLocale utf8_locale_scope(CP_UTF8);
127+
128+
bool isNullTerminated = false;
110129
size_t rv;
111-
const char *prevLocale = setlocale(LC_ALL, nullptr);
112-
setlocale(LC_ALL, "en_US.UTF-8");
113130
if (lpWideCharStr[cchWideChar - 1] != L'\0') {
114131
wchar_t *srcStr = (wchar_t *)malloc((cchWideChar + 1) * sizeof(wchar_t));
115132
wcsncpy(srcStr, lpWideCharStr, cchWideChar);
@@ -118,21 +135,41 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
118135
free(srcStr);
119136
} else {
120137
rv = wcstombs(lpMultiByteStr, lpWideCharStr, cbMultiByte);
138+
isNullTerminated = true;
139+
}
140+
141+
if (rv == ~(size_t)0) {
142+
// wcstombs returns -1 on error.
143+
SetLastError(ERROR_INVALID_PARAMETER);
144+
return 0;
121145
}
122146

123-
if (prevLocale)
124-
setlocale(LC_ALL, prevLocale);
147+
// Return value of wcstombs (rv) excludes the terminating character.
148+
// Matching WideCharToMultiByte requires returning the size written including
149+
// the null terminator if the input was null-terminated, otherwise it
150+
// returns the size written excluding the null terminator.
151+
if (isNullTerminated)
152+
rv += 1;
153+
154+
// Check for overflow when returning the size.
155+
if (rv >= INT32_MAX) {
156+
SetLastError(ERROR_INVALID_PARAMETER);
157+
return 0; // Overflow error
158+
}
125159

126-
if (rv == (size_t)cchWideChar)
127-
return rv;
128-
return rv + 1; // mbstowcs excludes the terminating character
160+
return rv;
129161
}
130162
#endif // _WIN32
131163

132164
namespace Unicode {
133165

134166
bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp,
135167
DWORD flags, std::string *pValue, bool *lossy) {
168+
DXASSERT_NOMSG(cWide == ~(size_t)0 || cWide < INT32_MAX);
169+
if (text == nullptr || pValue == nullptr || cWide == 0 ||
170+
!(cWide == ~(size_t)0 || cWide < INT32_MAX))
171+
return false;
172+
136173
BOOL usedDefaultChar;
137174
LPBOOL pUsedDefaultChar = (lossy == nullptr) ? nullptr : &usedDefaultChar;
138175
if (lossy != nullptr)
@@ -147,31 +184,37 @@ bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp,
147184
return true;
148185
}
149186

150-
int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, nullptr, 0,
151-
nullptr, pUsedDefaultChar);
187+
int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast<int>(cWide),
188+
nullptr, 0, nullptr, pUsedDefaultChar);
152189
if (cbUTF8 == 0)
153190
return false;
154191

155192
pValue->resize(cbUTF8);
156193

157-
cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, &(*pValue)[0],
158-
pValue->size(), nullptr, pUsedDefaultChar);
194+
cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast<int>(cWide),
195+
&(*pValue)[0], pValue->size(), nullptr,
196+
pUsedDefaultChar);
159197
DXASSERT(cbUTF8 > 0, "otherwise contents have changed");
160-
DXASSERT((*pValue)[pValue->size()] == '\0',
161-
"otherwise string didn't null-terminate after resize() call");
198+
if ((cWide == ~(size_t)0 || text[cWide - 1] == L'\0') &&
199+
(*pValue)[pValue->size() - 1] == '\0') {
200+
// When the input is null-terminated, the output includes the null
201+
// terminator. Reduce the size by 1 to remove the embedded null terminator
202+
// inside the string.
203+
pValue->resize(cbUTF8 - 1);
204+
}
162205

163206
if (lossy != nullptr)
164207
*lossy = usedDefaultChar;
165208
return true;
166209
}
167210

168211
bool UTF8ToWideString(const char *pUTF8, std::wstring *pWide) {
169-
size_t cbUTF8 = (pUTF8 == nullptr) ? 0 : strlen(pUTF8);
170-
return UTF8ToWideString(pUTF8, cbUTF8, pWide);
212+
return UTF8ToWideString(pUTF8, -1, pWide);
171213
}
172214

173215
bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) {
174216
DXASSERT_NOMSG(pWide != nullptr);
217+
DXASSERT_NOMSG(cbUTF8 == ~(size_t)0 || cbUTF8 < INT32_MAX);
175218

176219
// Handle zero-length as a special case; it's a special value to indicate
177220
// errors in MultiByteToWideChar.
@@ -181,17 +224,23 @@ bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) {
181224
}
182225

183226
int cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8,
184-
cbUTF8, nullptr, 0);
227+
static_cast<int>(cbUTF8), nullptr, 0);
185228
if (cWide == 0)
186229
return false;
187230

188231
pWide->resize(cWide);
189232

190-
cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8, cbUTF8,
191-
&(*pWide)[0], pWide->size());
233+
cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8,
234+
static_cast<int>(cbUTF8), &(*pWide)[0],
235+
pWide->size());
192236
DXASSERT(cWide > 0, "otherwise contents changed");
193-
DXASSERT((*pWide)[pWide->size()] == L'\0',
194-
"otherwise wstring didn't null-terminate after resize() call");
237+
if ((cbUTF8 == ~(size_t)0 || pUTF8[cbUTF8 - 1] == '\0') &&
238+
(*pWide)[pWide->size() - 1] == '\0') {
239+
// When the input is null-terminated, the output includes the null
240+
// terminator. Reduce the size by 1 to remove the embedded null terminator
241+
// inside the string.
242+
pWide->resize(cWide - 1);
243+
}
195244
return true;
196245
}
197246

@@ -213,11 +262,12 @@ bool UTF8ToConsoleString(const char *text, size_t textLen, std::string *pValue,
213262
if (!UTF8ToWideString(text, textLen, &text16)) {
214263
return false;
215264
}
216-
return WideToConsoleString(text16.c_str(), text16.length(), pValue, lossy);
265+
return WideToConsoleString(text16.c_str(), text16.length() + 1, pValue,
266+
lossy);
217267
}
218268

219269
bool UTF8ToConsoleString(const char *text, std::string *pValue, bool *lossy) {
220-
return UTF8ToConsoleString(text, strlen(text), pValue, lossy);
270+
return UTF8ToConsoleString(text, ~(size_t)0, pValue, lossy);
221271
}
222272

223273
bool WideToConsoleString(const wchar_t *text, size_t textLen,
@@ -230,7 +280,7 @@ bool WideToConsoleString(const wchar_t *text, size_t textLen,
230280

231281
bool WideToConsoleString(const wchar_t *text, std::string *pValue,
232282
bool *lossy) {
233-
return WideToConsoleString(text, wcslen(text), pValue, lossy);
283+
return WideToConsoleString(text, ~(size_t)0, pValue, lossy);
234284
}
235285

236286
bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) {
@@ -242,7 +292,7 @@ bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) {
242292
bool WideToUTF8String(const wchar_t *pWide, std::string *pUTF8) {
243293
DXASSERT_NOMSG(pWide != nullptr);
244294
DXASSERT_NOMSG(pUTF8 != nullptr);
245-
return WideToEncodedString(pWide, wcslen(pWide), CP_UTF8, 0, pUTF8, nullptr);
295+
return WideToEncodedString(pWide, ~(size_t)0, CP_UTF8, 0, pUTF8, nullptr);
246296
}
247297

248298
std::string WideToUTF8StringOrThrow(const wchar_t *pWide) {

tools/clang/unittests/HLSL/CompilerTest.cpp

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ class CompilerTest : public ::testing::Test {
207207
void TestEncodingImpl(const void *sourceData, size_t sourceSize,
208208
UINT32 codePage, const void *includedData,
209209
size_t includedSize, const WCHAR *encoding = nullptr);
210+
template <typename T1, typename T2>
211+
void TestEncodingImpl(std::basic_string<T1> source, UINT32 codePage,
212+
std::basic_string<T2> included,
213+
const WCHAR *encoding = nullptr) {
214+
TestEncodingImpl(source.data(), source.size() * sizeof(T1), codePage,
215+
included.data(), included.size() * sizeof(T2), encoding);
216+
}
210217
TEST_METHOD(CompileWithEncodeFlagTestSource)
211218

212219
#if _ITERATOR_DEBUG_LEVEL == 0
@@ -3636,54 +3643,47 @@ void CompilerTest::TestEncodingImpl(const void *sourceData, size_t sourceSize,
36363643

36373644
TEST_F(CompilerTest, CompileWithEncodeFlagTestSource) {
36383645

3639-
std::string sourceUtf8 = "#include \"include.hlsl\"\r\n"
3640-
"float4 main() : SV_Target { return 0; }";
3641-
std::string includeUtf8 = "// Comment\n";
3646+
std::string SourceUtf8 = "#include \"include.hlsl\"\n"
3647+
"float4 main() : SV_Target { return Buf[0]; }";
3648+
std::string IncludeUtf8 = "Buffer<float4> Buf;\n";
36423649
std::string utf8BOM = "\xEF"
36433650
"\xBB"
36443651
"\xBF"; // UTF-8 BOM
3645-
std::string includeUtf8BOM = utf8BOM + includeUtf8;
3652+
std::string IncludeUtf8BOM = utf8BOM + IncludeUtf8;
36463653

3647-
std::wstring sourceWide = L"#include \"include.hlsl\"\r\n"
3648-
L"float4 main() : SV_Target { return 0; }";
3649-
std::wstring includeWide = L"// Comments\n";
3650-
std::wstring utf16BOM = L"\xFEFF"; // UTF-16 LE BOM
3651-
std::wstring includeUtf16BOM = utf16BOM + includeWide;
3654+
std::wstring SourceWide = L"#include \"include.hlsl\"\n"
3655+
L"float4 main() : SV_Target { return Buf[0]; }";
3656+
std::wstring IncludeWide = L"Buffer<float4> Buf;\n";
36523657

3653-
// Included files interpreted with encoding option if no BOM
3654-
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
3655-
includeUtf8.data(), includeUtf8.size(), L"utf8");
3658+
// Windows: UTF-16 BOM is '\xFEFF'
3659+
// *nix: UTF-32 BOM is L'\x0000FEFF'
3660+
// Thus, BOM wide character value is identical for UTF-16 and UTF-32.
3661+
// Endianness will be native, since we are using wide strings directly.
3662+
std::wstring WideBOM = L"\xFEFF";
3663+
3664+
std::wstring IncludeWideBOM = WideBOM + IncludeWide;
36563665

3657-
TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
3658-
DXC_CP_WIDE, includeWide.data(),
3659-
includeWide.size() * sizeof(L'A'), L"wide");
3666+
// Included files interpreted with encoding option if no BOM
3667+
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8, L"utf8");
3668+
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWide, L"wide");
36603669

36613670
// Encoding option ignored if BOM present
3662-
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
3663-
includeUtf8BOM.data(), includeUtf8BOM.size(), L"wide");
3671+
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8BOM, L"wide");
3672+
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWideBOM, L"utf8");
36643673

3665-
TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
3666-
DXC_CP_WIDE, includeUtf16BOM.data(),
3667-
includeUtf16BOM.size() * sizeof(L'A'), L"utf8");
3674+
// Encoding option ignored if BOM present - different encoding for source
3675+
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8BOM, L"wide");
3676+
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWideBOM, L"utf8");
36683677

36693678
// Source file interpreted according to DxcBuffer encoding if not CP_ACP
36703679
// Included files interpreted with encoding option if no BOM
3671-
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
3672-
includeWide.data(), includeWide.size() * sizeof(L'A'),
3673-
L"wide");
3674-
3675-
TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
3676-
DXC_CP_WIDE, includeUtf8.data(), includeUtf8.size(),
3677-
L"utf8");
3680+
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWide, L"wide");
3681+
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8, L"utf8");
36783682

36793683
// Source file interpreted by encoding option if source DxcBuffer encoding =
36803684
// CP_ACP (default)
3681-
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_ACP,
3682-
includeUtf8.data(), includeUtf8.size(), L"utf8");
3683-
3684-
TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
3685-
DXC_CP_ACP, includeWide.data(),
3686-
includeWide.size() * sizeof(L'A'), L"wide");
3685+
TestEncodingImpl(SourceUtf8, DXC_CP_ACP, IncludeUtf8, L"utf8");
3686+
TestEncodingImpl(SourceWide, DXC_CP_ACP, IncludeWide, L"wide");
36873687
}
36883688

36893689
TEST_F(CompilerTest, CompileWhenODumpThenOptimizerMatch) {

0 commit comments

Comments
 (0)