mariadb/plugin/win_auth_client/common.cc
2019-07-26 22:42:35 +02:00

510 lines
13 KiB
C++

/* Copyright (c) 2011, 2019, Oracle and/or its affiliates. All rights reserved.
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; version 2 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA */
#include "common.h"
#include <sddl.h> // for ConvertSidToStringSid()
#include <secext.h> // for GetUserNameEx()
/*
  Declarations of the explicit specializations of error_log_print() for
  the INFO, WARNING and ERROR log levels. Only the declarations appear
  here; the definitions are not part of this file (presumably they live
  with the rest of the logging implementation - confirm).
*/
template <> void error_log_print<error_log_level::INFO>(const char *fmt, ...);
template <> void error_log_print<error_log_level::WARNING>(const char *fmt, ...);
template <> void error_log_print<error_log_level::ERROR>(const char *fmt, ...);
/**
  Option indicating desired level of logging. Values:

  0 - no logging
  1 - log only error messages
  2 - additionally log warnings
  3 - additionally log info notes
  4 - also log debug messages

  The value set here, 2, is the default: errors and warnings are logged.

  Value of this option should be taken into account in the
  implementation of error_log_vprint() function (see
  log_client.cc).

  Note: No error or debug messages are logged in production code
  (see logging macros in common.h).
*/
int opt_auth_win_log_level= 2;
/** Connection class **************************************************/
/**
  Wrap an active MYSQL_PLUGIN_VIO object in a Connection.

  @param[in] vio  the @c MYSQL_PLUGIN_VIO object that carries the
                  connection; it must not be NULL (checked with an
                  assertion in debug builds)

  The error indicator of the new object starts out as 0.
*/

Connection::Connection(MYSQL_PLUGIN_VIO *vio)
  : m_vio(vio),
    m_error(0)
{
  DBUG_ASSERT(vio);
}
/**
  Send the contents of a Blob over the connection as one packet.

  @param[in] blob  data to be written

  @return 0 on success, VIO error code on failure.

  @note The value returned by the VIO layer is also remembered in the
  connection object, where it can be retrieved with @c error() method.
*/

int Connection::write(const Blob &blob)
{
  const int payload_len= (int)blob.len();

  m_error= m_vio->write_packet(m_vio, blob.ptr(), payload_len);

#ifndef DBUG_OFF
  if (m_error != 0)
  {
    DBUG_PRINT("error", ("vio write error %d", m_error));
  }
#endif

  return m_error;
}
/**
  Receive one packet from the connection.

  @return A Blob referring to the packet that was read, or a null Blob if
  the VIO layer reported a failure (negative length).

  @note On failure the error indicator of the connection object is set to
  a non-zero value, so that @c error() method reports the problem. On
  success the error indicator is left as it was.
*/

Blob Connection::read()
{
  unsigned char *packet;
  const int packet_len= m_vio->read_packet(m_vio, &packet);

  if (packet_len >= 0)
    return Blob(packet, packet_len);

  // A negative length is how the VIO layer signals a read failure.
  m_error= true;
  return Blob();
}
/** Sid class *****************************************************/
/**
  Create Sid object corresponding to a given account name.

  @param[in] account_name  name of a Windows account

  The account name can be in any form accepted by @c LookupAccountName()
  function.

  The SID is kept in a single byte buffer: a @c TOKEN_USER header followed
  directly by the SID data, with @c User.Sid pointing just past the header.

  @note In case of errors created object is invalid and its @c is_valid()
  method returns @c false.
*/

Sid::Sid(const wchar_t *account_name): m_data(NULL)
#ifndef DBUG_OFF
, m_as_string(NULL)
#endif
{
  DWORD sid_size= 0, domain_size= 0;
  bool success;

  // Determine required buffer sizes

  success= LookupAccountNameW(NULL, account_name, NULL, &sid_size,
                              NULL, &domain_size, &m_type);

  // With no buffers supplied, ERROR_INSUFFICIENT_BUFFER is the expected outcome.
  if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
  {
#ifndef DBUG_OFF
    Error_message_buf error_buf;
    DBUG_PRINT("error", ("Could not determine SID buffer size, "
                         "LookupAccountName() failed with error %X (%s)",
                         GetLastError(), get_last_error_message(error_buf)));
#endif
    return;
  }

  // Query for SID (domain is ignored)

  wchar_t *domain= new wchar_t[domain_size];
  m_data= (TOKEN_USER*) new BYTE[sid_size + sizeof(TOKEN_USER)];
  m_data->User.Sid= (BYTE*)m_data + sizeof(TOKEN_USER);

  success= LookupAccountNameW(NULL, account_name,
                              m_data->User.Sid, &sid_size,
                              domain, &domain_size,
                              &m_type);

  if (!success || !is_valid())
  {
#ifndef DBUG_OFF
    Error_message_buf error_buf;
    DBUG_PRINT("error", ("Could not determine SID of '%S', "
                         "LookupAccountName() failed with error %X (%s)",
                         account_name, GetLastError(),
                         get_last_error_message(error_buf)));
#endif
    /*
      The buffer was allocated as BYTE[], so it has to be released as
      BYTE[] too: delete[] through a pointer of another type is undefined.
    */
    delete [] (BYTE*)m_data;
    m_data= NULL;
  }

  // The domain name was needed only to satisfy LookupAccountName().
  delete [] domain;
}
/**
  Create Sid object corresponding to a given security token.

  @param[in] token  security token of a Windows account

  The @c TokenUser information of the token is read into a byte buffer
  whose size is first obtained from @c GetTokenInformation().

  @note In case of errors created object is invalid and its @c is_valid()
  method returns @c false.
*/

Sid::Sid(HANDLE token): m_data(NULL)
#ifndef DBUG_OFF
, m_as_string(NULL)
#endif
{
  DWORD req_size= 0;
  bool success;

  // Determine required buffer size

  success= GetTokenInformation(token, TokenUser, NULL, 0, &req_size);

  // With no buffer supplied, ERROR_INSUFFICIENT_BUFFER is the expected outcome.
  if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
  {
#ifndef DBUG_OFF
    Error_message_buf error_buf;
    DBUG_PRINT("error", ("Could not determine SID buffer size, "
                         "GetTokenInformation() failed with error %X (%s)",
                         GetLastError(), get_last_error_message(error_buf)));
#endif
    return;
  }

  m_data= (TOKEN_USER*) new BYTE[req_size];
  success= GetTokenInformation(token, TokenUser, m_data, req_size, &req_size);

  if (!success || !is_valid())
  {
#ifndef DBUG_OFF
    /*
      Report the failure before releasing the buffer, so that
      GetLastError() still refers to the GetTokenInformation() call.
    */
    if (!success)
    {
      Error_message_buf error_buf;
      DBUG_PRINT("error", ("Could not read SID from security token, "
                           "GetTokenInformation() failed with error %X (%s)",
                           GetLastError(), get_last_error_message(error_buf)));
    }
#endif
    /*
      The buffer was allocated as BYTE[], so it has to be released as
      BYTE[] too: delete[] through a pointer of another type is undefined.
    */
    delete [] (BYTE*)m_data;
    m_data= NULL;
  }
}
/**
  Release the SID buffer and, in debug builds, the cached string
  representation of the SID.
*/

Sid::~Sid()
{
  /*
    The constructors allocate the buffer as BYTE[], so it is released as
    BYTE[]: delete[] through a pointer of another type is undefined.
    Deleting a null pointer is a no-op, hence no check.
  */
  delete [] (BYTE*)m_data;
#ifndef DBUG_OFF
  // The string comes from ConvertSidToStringSid() and is freed with LocalFree().
  if (m_as_string)
    LocalFree(m_as_string);
#endif
}
/// Tell whether this object holds a buffer with a SID accepted by IsValidSid().

bool Sid::is_valid(void) const
{
  // Without a buffer, or without a SID pointer in it, there is nothing to check.
  if (!m_data || !m_data->User.Sid)
    return false;

  return IsValidSid(m_data->User.Sid) != 0;
}
#ifndef DBUG_OFF

/**
  Return the textual form of the SID (compiled only when DBUG_OFF is not
  defined).

  The string is produced on the first successful call and reused by later
  calls.

  @return String representation of the SID, or NULL if the object holds
  no SID buffer or the conversion failed.

  @note Memory allocated for the string is automatically freed in Sid's
  destructor.
*/

const char* Sid::as_string()
{
  if (!m_data)
    return NULL;

  // Converted by an earlier call - hand out the same string again.
  if (m_as_string)
    return m_as_string;

  if (ConvertSidToStringSid(m_data->User.Sid, &m_as_string))
    return m_as_string;

  // Conversion failed: log the reason and make sure nothing is cached.
  Error_message_buf error_buf;
  DBUG_PRINT("error", ("Could not get textual representation of a SID, "
                       "ConvertSidToStringSid() failed with error %X (%s)",
                       GetLastError(), get_last_error_message(error_buf)));

  m_as_string= NULL;
  return NULL;
}

#endif
/**
  Compare two Sid objects.

  @param[in] other  the object to compare with

  @return true only if both objects are valid and EqualSid() reports their
  SIDs as equal; an invalid object never compares equal to anything.
*/

bool Sid::operator ==(const Sid &other)
{
  return is_valid()
         && other.is_valid()
         && EqualSid(m_data->User.Sid, other.m_data->User.Sid) != 0;
}
/** Generating User Principal Name *************************/
/**
  Call Windows API functions to get UPN of the current user and store it
  in internal buffer.

  On success @c m_buf holds the UPN converted to utf8 and @c m_len its
  length in bytes. On any failure @c m_buf is NULL and @c m_len is 0
  (m_len == 0 indicates invalid UPN).
*/

UPN::UPN(): m_buf(NULL)
{
  wchar_t buf1[MAX_SERVICE_NAME_LENGTH];
  const unsigned long buf1_chars=
    (unsigned long)(sizeof(buf1)/sizeof(wchar_t));

  /*
    GetUserNameEx() reads and writes the size through a pointer to
    unsigned long, so a variable of exactly that type is used rather than
    a cast of &m_len, which is a size_t and may be wider.
  */
  unsigned long upn_chars= buf1_chars;

  m_len= 0;   // m_len == 0 indicates invalid UPN

  // First we try to use GetUserNameEx.

  if (!GetUserNameExW(NameUserPrincipal, buf1, &upn_chars))
  {
    // Read the error code once, before logging can disturb it.
    const unsigned long last_error= GetLastError();

    if (last_error)
    {
#ifndef DBUG_OFF
      Error_message_buf error_buf;
      DBUG_PRINT("note", ("When determining UPN"
                          ", GetUserNameEx() failed with error %X (%s)",
                          last_error, get_last_error_message(error_buf)));
#endif
      // On ERROR_MORE_DATA upn_chars holds the required buffer size.
      if (ERROR_MORE_DATA == last_error)
      {
        ERROR_LOG(INFO, ("Buffer overrun when determining UPN:"
                         " need %lu characters but have %lu",
                         upn_chars, buf1_chars));
      }
    }
    return;
  }

  /*
    UPN is stored in buf1 in wide-char format - convert it to utf8
    for sending over network.
  */

  m_len= upn_chars;
  m_buf= wchar_to_utf8(buf1, &m_len);

  if (!m_buf)
  {
    ERROR_LOG(ERROR, ("Failed to convert UPN to utf8"));
    // Without a buffer the object must not report a non-zero length.
    m_len= 0;
  }
}
/// Release the utf8 buffer holding the UPN, if there is one.

UPN::~UPN()
{
  // free() does nothing when given a null pointer, so no check is needed.
  free(m_buf);
}
/**
  Convert a wide-char string to utf8 representation.

  @param[in]     string  null-terminated wide-char string to be converted
  @param[in,out] len     length of the string to be converted or 0; on
                         return length (in bytes, excluding terminating
                         null character) of the converted string

  If len is 0 then the length of the string will be computed by this function.

  @return Pointer to a buffer containing utf8 representation or NULL in
  case of error. On error @c *len is left unchanged.

  @note The returned buffer must be freed with @c free() call.

  NOTE(review): an empty input string appears to take the error path,
  since the conversion is then asked to convert 0 characters - confirm
  against the WideCharToMultiByte() documentation and the callers.
*/

char* wchar_to_utf8(const wchar_t *string, size_t *len)
{
  size_t str_len= len && *len ? *len : wcslen(string);

  /*
    A conversion from wchar_t to utf8 will never take more than 3 bytes per
    wide character, so a buffer of length 3 * str_len should be sufficient.
    We check that assumption with an assertion later.
  */
  size_t buf_len= 3 * str_len;

  // One extra byte for the terminating null character added below.
  char *buf= (char*)malloc(buf_len + 1);
  if (!buf)
  {
    DBUG_PRINT("error",("Out of memory when converting string '%S' to utf8",
                        string));
    return NULL;
  }

  int res= WideCharToMultiByte(CP_UTF8,            // convert to UTF-8
                               0,                  // conversion flags
                               string,             // input buffer
                               (int)str_len,       // its length
                               buf, (int)buf_len,  // output buffer and its size
                               NULL, NULL);        // default character (not used)

  if (res)
  {
    buf[res]= '\0';
    if (len)
      *len= res;
    return buf;
  }

  // res is 0 which indicates error

#ifndef DBUG_OFF
  Error_message_buf error_buf;
  DBUG_PRINT("error", ("Could not convert string '%S' to utf8"
                       ", WideCharToMultiByte() failed with error %X (%s)",
                       string, GetLastError(),
                       get_last_error_message(error_buf)));
#endif

  // Let's check our assumption about sufficient buffer size
  DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError());

  // The caller receives no pointer on failure, so the buffer is released here.
  free(buf);
  return NULL;
}
/**
  Convert an utf8 string to a wide-char string.

  @param[in]     string  null-terminated utf8 string to be converted
  @param[in,out] len     length of the string to be converted or 0; on
                         return length (in chars) of the converted string

  If len is 0 then the length of the string will be computed by this function.

  @return Pointer to a buffer containing wide-char representation or NULL in
  case of error. On error @c *len is left unchanged.

  @note The returned buffer must be freed with @c free() call.

  NOTE(review): an empty input string appears to take the error path,
  since the conversion is then asked to convert 0 bytes - confirm against
  the MultiByteToWideChar() documentation and the callers.
*/

wchar_t* utf8_to_wchar(const char *string, size_t *len)
{
  size_t buf_len;

  /*
    Note: length (in bytes) of an utf8 string is never smaller than the
    number of wide characters needed to represent it. Hence a buffer of
    buf_len wide characters will be sufficient. We add 1 for the
    terminating null character.
  */
  buf_len= len && *len ? *len : strlen(string);

  wchar_t *buf= (wchar_t*)malloc((buf_len+1)*sizeof(wchar_t));
  if (!buf)
  {
    DBUG_PRINT("error",("Out of memory when converting utf8 string '%s'"
                        " to wide-char representation", string));
    return NULL;
  }

  size_t res;
  res= MultiByteToWideChar(CP_UTF8,            // convert from UTF-8
                           0,                  // conversion flags
                           string,             // input buffer
                           (int)buf_len,       // its size
                           buf, (int)buf_len); // output buffer and its size
  if (res)
  {
    buf[res]= L'\0';
    if (len)
      *len= res;
    return buf;
  }

  // error in MultiByteToWideChar()

#ifndef DBUG_OFF
  Error_message_buf error_buf;
  DBUG_PRINT("error", ("Could not convert UPN from UTF-8"
                       ", MultiByteToWideChar() failed with error %X (%s)",
                       GetLastError(), get_last_error_message(error_buf)));
#endif

  // Let's check our assumption about sufficient buffer size
  DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError());

  // The caller receives no pointer on failure, so the buffer is released here.
  free(buf);
  return NULL;
}
/** Error handling ****************************************************/
/**
  Returns error message corresponding to the last Windows error given
  by GetLastError().

  @param[out] buf  caller-supplied buffer which receives the message; up
                   to ERRMSG_BUFSIZE characters of it are written

  @return @c buf. It is set to an empty string before the lookup, so it is
  always a valid null-terminated string.

  @note The message is stored in the buffer supplied by the caller, so it
  stays valid for as long as that buffer does.
*/

const char* get_last_error_message(Error_message_buf buf)
{
  const DWORD error= GetLastError();

  buf[0]= '\0';

  /*
    The buffer holds plain chars, so the ANSI variant is called explicitly
    instead of depending on how the TCHAR macros resolve.
    FORMAT_MESSAGE_IGNORE_INSERTS is required because no insert arguments
    are supplied: without it a system message containing insert sequences
    would make the call look for arguments that are not there.
  */
  FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                 NULL, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
                 (LPSTR)buf, ERRMSG_BUFSIZE, NULL);

  return buf;
}