mirror of
				https://github.com/MariaDB/server.git
				synced 2025-11-04 04:46:15 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			510 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			510 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
/* Copyright (c) 2011, 2019, Oracle and/or its affiliates. All rights reserved.
 | 
						|
 | 
						|
   This program is free software; you can redistribute it and/or modify
 | 
						|
   it under the terms of the GNU General Public License as published by
 | 
						|
   the Free Software Foundation; version 2 of the License.
 | 
						|
 | 
						|
   This program is distributed in the hope that it will be useful,
 | 
						|
   but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
						|
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
						|
   GNU General Public License for more details.
 | 
						|
 | 
						|
   You should have received a copy of the GNU General Public License
 | 
						|
   along with this program; if not, write to the Free Software
 | 
						|
   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1335  USA */
 | 
						|
 | 
						|
#include "common.h"
 | 
						|
#include <sddl.h>   // for ConvertSidToStringSid()
 | 
						|
#include <secext.h> // for GetUserNameEx()
 | 
						|
 | 
						|
 | 
						|
template <> void error_log_print<error_log_level::INFO>(const char *fmt, ...);
 | 
						|
template <> void error_log_print<error_log_level::WARNING>(const char *fmt, ...);
 | 
						|
template <> void error_log_print<error_log_level::ERROR>(const char *fmt, ...);
 | 
						|
 | 
						|
/**
 | 
						|
  Option indicating desired level of logging. Values:
 | 
						|
 | 
						|
  0 - no logging
 | 
						|
  1 - log only error messages
 | 
						|
  2 - additionally log warnings
 | 
						|
  3 - additionally log info notes
 | 
						|
  4 - also log debug messages
 | 
						|
 | 
						|
  Value of this option should be taken into account in the 
 | 
						|
  implementation of  error_log_vprint() function (see 
 | 
						|
  log_client.cc).
 | 
						|
 | 
						|
  Note: No error or debug messages are logged in production code
 | 
						|
  (see logging macros in common.h).
 | 
						|
*/
 | 
						|
int opt_auth_win_log_level= 2;
 | 
						|
 | 
						|
 | 
						|
/** Connection class **************************************************/
 | 
						|
 | 
						|
/**
 | 
						|
  Create connection out of an active MYSQL_PLUGIN_VIO object.
 | 
						|
 | 
						|
  @param[in] vio  pointer to a @c MYSQL_PLUGIN_VIO object used for
 | 
						|
                  connection - it can not be NULL
 | 
						|
*/
 | 
						|
 | 
						|
Connection::Connection(MYSQL_PLUGIN_VIO *vio): m_vio(vio), m_error(0)
 | 
						|
{
 | 
						|
  DBUG_ASSERT(vio);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
  Write data to the connection.
 | 
						|
 | 
						|
  @param[in]  blob  data to be written
 | 
						|
 | 
						|
  @return 0 on success, VIO error code on failure.
 | 
						|
 | 
						|
  @note In case of error, VIO error code is stored in the connection object
 | 
						|
  and can be obtained with @c error() method.
 | 
						|
*/
 | 
						|
 | 
						|
int Connection::write(const Blob &blob)
 | 
						|
{
 | 
						|
  m_error= m_vio->write_packet(m_vio, blob.ptr(), (int)blob.len());
 | 
						|
 | 
						|
#ifndef DBUG_OFF
 | 
						|
  if (m_error)
 | 
						|
    DBUG_PRINT("error", ("vio write error %d", m_error));
 | 
						|
#endif
 | 
						|
 | 
						|
  return m_error;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
  Read data from connection.
 | 
						|
 | 
						|
  @return A Blob containing read packet or null Blob in case of error.
 | 
						|
 | 
						|
  @note In case of error, VIO error code is stored in the connection object
 | 
						|
  and can be obtained with @c error() method.
 | 
						|
*/
 | 
						|
 | 
						|
Blob Connection::read()
 | 
						|
{
 | 
						|
  unsigned char *ptr;
 | 
						|
  int len= m_vio->read_packet(m_vio, &ptr);
 | 
						|
 | 
						|
  if (len < 0)
 | 
						|
  {
 | 
						|
    m_error= true;
 | 
						|
    return Blob();
 | 
						|
  }
 | 
						|
 | 
						|
  return Blob(ptr, len);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/** Sid class *****************************************************/
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
  Create Sid object corresponding to a given account name.
 | 
						|
 | 
						|
  @param[in]  account_name  name of a Windows account
 | 
						|
 | 
						|
  The account name can be in any form accepted by @c LookupAccountName()
 | 
						|
  function.
 | 
						|
 | 
						|
  @note In case of errors created object is invalid and its @c is_valid()
 | 
						|
  method returns @c false.
 | 
						|
*/
 | 
						|
 | 
						|
Sid::Sid(const wchar_t *account_name): m_data(NULL)
 | 
						|
#ifndef DBUG_OFF
 | 
						|
, m_as_string(NULL)
 | 
						|
#endif
 | 
						|
{
 | 
						|
  DWORD sid_size= 0, domain_size= 0;
 | 
						|
  bool success;
 | 
						|
 | 
						|
  // Determine required buffer sizes
 | 
						|
 | 
						|
  success= LookupAccountNameW(NULL, account_name, NULL, &sid_size,
 | 
						|
                             NULL, &domain_size, &m_type);
 | 
						|
 | 
						|
  if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
 | 
						|
  {
 | 
						|
#ifndef DBUG_OFF
 | 
						|
    Error_message_buf error_buf;
 | 
						|
    DBUG_PRINT("error", ("Could not determine SID buffer size, "
 | 
						|
                         "LookupAccountName() failed with error %X (%s)",
 | 
						|
                         GetLastError(), get_last_error_message(error_buf)));
 | 
						|
#endif
 | 
						|
    return;
 | 
						|
  }
 | 
						|
 | 
						|
  // Query for SID (domain is ignored)
 | 
						|
 | 
						|
  wchar_t *domain= new wchar_t[domain_size];
 | 
						|
  m_data= (TOKEN_USER*) new BYTE[sid_size + sizeof(TOKEN_USER)];
 | 
						|
  m_data->User.Sid= (BYTE*)m_data + sizeof(TOKEN_USER);
 | 
						|
 | 
						|
  success= LookupAccountNameW(NULL, account_name,
 | 
						|
                             m_data->User.Sid, &sid_size,
 | 
						|
                             domain, &domain_size,
 | 
						|
                             &m_type);
 | 
						|
 | 
						|
  if (!success || !is_valid())
 | 
						|
  {
 | 
						|
#ifndef DBUG_OFF
 | 
						|
    Error_message_buf error_buf;
 | 
						|
    DBUG_PRINT("error", ("Could not determine SID of '%S', "
 | 
						|
                         "LookupAccountName() failed with error %X (%s)",
 | 
						|
                         account_name, GetLastError(),
 | 
						|
                         get_last_error_message(error_buf)));
 | 
						|
#endif
 | 
						|
    goto fail;
 | 
						|
  }
 | 
						|
 | 
						|
  goto end;
 | 
						|
 | 
						|
fail:
 | 
						|
  if (m_data)
 | 
						|
    delete [] m_data;
 | 
						|
  m_data= NULL;
 | 
						|
 | 
						|
end:
 | 
						|
  if (domain)
 | 
						|
    delete [] domain;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
  Create Sid object corresponding to a given security token.
 | 
						|
 | 
						|
  @param[in]  token   security token of a Windows account
 | 
						|
 | 
						|
  @note In case of errors created object is invalid and its @c is_valid()
 | 
						|
  method returns @c false.
 | 
						|
*/
 | 
						|
 | 
						|
Sid::Sid(HANDLE token): m_data(NULL)
 | 
						|
#ifndef DBUG_OFF
 | 
						|
, m_as_string(NULL)
 | 
						|
#endif
 | 
						|
{
 | 
						|
  DWORD             req_size= 0;
 | 
						|
  bool              success;
 | 
						|
 | 
						|
  // Determine required buffer size
 | 
						|
 | 
						|
  success= GetTokenInformation(token, TokenUser, NULL, 0, &req_size);
 | 
						|
  if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
 | 
						|
  {
 | 
						|
#ifndef DBUG_OFF
 | 
						|
    Error_message_buf error_buf;
 | 
						|
    DBUG_PRINT("error", ("Could not determine SID buffer size, "
 | 
						|
                         "GetTokenInformation() failed with error %X (%s)",
 | 
						|
                         GetLastError(), get_last_error_message(error_buf)));
 | 
						|
#endif
 | 
						|
    return;
 | 
						|
  }
 | 
						|
 | 
						|
  m_data= (TOKEN_USER*) new BYTE[req_size];
 | 
						|
  success= GetTokenInformation(token, TokenUser, m_data, req_size, &req_size);
 | 
						|
 | 
						|
  if (!success || !is_valid())
 | 
						|
  {
 | 
						|
    delete [] m_data;
 | 
						|
    m_data= NULL;
 | 
						|
#ifndef DBUG_OFF
 | 
						|
    if (!success)
 | 
						|
    {
 | 
						|
      Error_message_buf error_buf;
 | 
						|
      DBUG_PRINT("error", ("Could not read SID from security token, "
 | 
						|
                           "GetTokenInformation() failed with error %X (%s)",
 | 
						|
                           GetLastError(), get_last_error_message(error_buf)));
 | 
						|
    }
 | 
						|
#endif
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
Sid::~Sid()
 | 
						|
{
 | 
						|
  if (m_data)
 | 
						|
    delete [] m_data;
 | 
						|
#ifndef DBUG_OFF
 | 
						|
  if (m_as_string)
 | 
						|
    LocalFree(m_as_string);
 | 
						|
#endif
 | 
						|
}
 | 
						|
 | 
						|
/// Check if Sid object is valid.
 | 
						|
bool Sid::is_valid(void) const
 | 
						|
{
 | 
						|
  return m_data && m_data->User.Sid && IsValidSid(m_data->User.Sid);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
#ifndef DBUG_OFF
 | 
						|
 | 
						|
/**
 | 
						|
  Produces string representation of the SID.
 | 
						|
 | 
						|
  @return String representation of the SID or NULL in case of errors.
 | 
						|
 | 
						|
  @note Memory allocated for the string is automatically freed in Sid's
 | 
						|
  destructor.
 | 
						|
*/
 | 
						|
 | 
						|
const char* Sid::as_string()
 | 
						|
{
 | 
						|
  if (!m_data)
 | 
						|
    return NULL;
 | 
						|
 | 
						|
  if (!m_as_string)
 | 
						|
  {
 | 
						|
    bool success= ConvertSidToStringSid(m_data->User.Sid, &m_as_string);
 | 
						|
 | 
						|
    if (!success)
 | 
						|
    {
 | 
						|
#ifndef DBUG_OFF
 | 
						|
      Error_message_buf error_buf;
 | 
						|
      DBUG_PRINT("error", ("Could not get textual representation of a SID, "
 | 
						|
                           "ConvertSidToStringSid() failed with error %X (%s)",
 | 
						|
                           GetLastError(), get_last_error_message(error_buf)));
 | 
						|
#endif
 | 
						|
      m_as_string= NULL;
 | 
						|
      return NULL;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  return m_as_string;
 | 
						|
}
 | 
						|
 | 
						|
#endif
 | 
						|
 | 
						|
 | 
						|
bool Sid::operator ==(const Sid &other)
 | 
						|
{
 | 
						|
  if (!is_valid() || !other.is_valid())
 | 
						|
    return false;
 | 
						|
 | 
						|
  return EqualSid(m_data->User.Sid, other.m_data->User.Sid);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/** Generating User Principal Name *************************/
 | 
						|
 | 
						|
/**
 | 
						|
  Call Windows API functions to get UPN of the current user and store it
 | 
						|
  in internal buffer.
 | 
						|
*/
 | 
						|
 | 
						|
UPN::UPN(): m_buf(NULL)
 | 
						|
{
 | 
						|
  wchar_t  buf1[MAX_SERVICE_NAME_LENGTH];
 | 
						|
 | 
						|
  // First we try to use GetUserNameEx.
 | 
						|
 | 
						|
  m_len= sizeof(buf1)/sizeof(wchar_t);
 | 
						|
 | 
						|
  if (!GetUserNameExW(NameUserPrincipal, buf1, (PULONG)&m_len))
 | 
						|
  {
 | 
						|
    if (GetLastError())
 | 
						|
    {
 | 
						|
#ifndef DBUG_OFF
 | 
						|
      Error_message_buf error_buf;
 | 
						|
      DBUG_PRINT("note", ("When determining UPN"
 | 
						|
                          ", GetUserNameEx() failed with error %X (%s)",
 | 
						|
                          GetLastError(), get_last_error_message(error_buf)));
 | 
						|
#endif
 | 
						|
      if (ERROR_MORE_DATA == GetLastError())
 | 
						|
        ERROR_LOG(INFO, ("Buffer overrun when determining UPN:"
 | 
						|
                         " need %ul characters but have %ul",
 | 
						|
                         m_len, sizeof(buf1)/sizeof(WCHAR)));
 | 
						|
    }
 | 
						|
 | 
						|
    m_len= 0;   // m_len == 0 indicates invalid UPN
 | 
						|
    return;
 | 
						|
  }
 | 
						|
 | 
						|
  /*
 | 
						|
    UPN is stored in buf1 in wide-char format - convert it to utf8
 | 
						|
    for sending over network.
 | 
						|
  */
 | 
						|
 | 
						|
  m_buf= wchar_to_utf8(buf1, &m_len);
 | 
						|
 | 
						|
  if(!m_buf)
 | 
						|
    ERROR_LOG(ERROR, ("Failed to convert UPN to utf8"));
 | 
						|
 | 
						|
  // Note: possible error would be indicated by the fact that m_buf is NULL.
 | 
						|
  return;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
UPN::~UPN()
 | 
						|
{
 | 
						|
  if (m_buf)
 | 
						|
    free(m_buf);
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
  Convert a wide-char string to utf8 representation.
 | 
						|
 | 
						|
  @param[in]     string   null-terminated wide-char string to be converted
 | 
						|
  @param[in,out] len      length of the string to be converted or 0; on
 | 
						|
                          return length (in bytes, excluding terminating
 | 
						|
                          null character) of the converted string
 | 
						|
 | 
						|
  If len is 0 then the length of the string will be computed by this function.
 | 
						|
 | 
						|
  @return Pointer to a buffer containing utf8 representation or NULL in
 | 
						|
          case of error.
 | 
						|
 | 
						|
  @note The returned buffer must be freed with @c free() call.          
 | 
						|
*/
 | 
						|
 | 
						|
char* wchar_to_utf8(const wchar_t *string, size_t *len)
 | 
						|
{
 | 
						|
  char   *buf= NULL; 
 | 
						|
  size_t  str_len= len && *len ? *len : wcslen(string);
 | 
						|
 | 
						|
  /*
 | 
						|
    A conversion from utf8 to wchar_t will never take more than 3 bytes per
 | 
						|
    character, so a buffer of length 3 * str_len should be sufficient. 
 | 
						|
    We check that assumption with an assertion later.
 | 
						|
  */
 | 
						|
 | 
						|
  size_t  buf_len= 3 * str_len;
 | 
						|
 | 
						|
  buf= (char*)malloc(buf_len + 1);
 | 
						|
  if (!buf)
 | 
						|
  {
 | 
						|
    DBUG_PRINT("error",("Out of memory when converting string '%S' to utf8",
 | 
						|
                        string));
 | 
						|
    return NULL;
 | 
						|
  }
 | 
						|
 | 
						|
  int res= WideCharToMultiByte(CP_UTF8,              // convert to UTF-8
 | 
						|
                               0,                    // conversion flags
 | 
						|
                               string,               // input buffer
 | 
						|
                               (int)str_len,         // its length
 | 
						|
                               buf, (int)buf_len,         // output buffer and its size
 | 
						|
                               NULL, NULL);          // default character (not used)
 | 
						|
 | 
						|
  if (res)
 | 
						|
  {
 | 
						|
    buf[res]= '\0';
 | 
						|
    if (len)
 | 
						|
      *len= res;
 | 
						|
    return buf;
 | 
						|
  }
 | 
						|
 | 
						|
  // res is 0 which indicates error
 | 
						|
 | 
						|
#ifndef DBUG_OFF
 | 
						|
  Error_message_buf error_buf;
 | 
						|
  DBUG_PRINT("error", ("Could not convert string '%S' to utf8"
 | 
						|
                       ", WideCharToMultiByte() failed with error %X (%s)",
 | 
						|
                       string, GetLastError(), 
 | 
						|
                       get_last_error_message(error_buf)));
 | 
						|
#endif
 | 
						|
 | 
						|
  // Let's check our assumption about sufficient buffer size
 | 
						|
  DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError());
 | 
						|
 | 
						|
  return NULL;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
  Convert an utf8 string to a wide-char string.
 | 
						|
 | 
						|
  @param[in]     string   null-terminated utf8 string to be converted
 | 
						|
  @param[in,out] len      length of the string to be converted or 0; on
 | 
						|
                          return length (in chars) of the converted string
 | 
						|
 | 
						|
  If len is 0 then the length of the string will be computed by this function.
 | 
						|
 | 
						|
  @return Pointer to a buffer containing wide-char representation or NULL in
 | 
						|
          case of error.
 | 
						|
 | 
						|
  @note The returned buffer must be freed with @c free() call.          
 | 
						|
*/
 | 
						|
 | 
						|
wchar_t* utf8_to_wchar(const char *string, size_t *len)
 | 
						|
{
 | 
						|
  size_t buf_len;
 | 
						|
 | 
						|
  /* 
 | 
						|
    Note: length (in bytes) of an utf8 string is always bigger than the
 | 
						|
    number of characters in this string. Hence a buffer of size len will
 | 
						|
    be sufficient. We add 1 for the terminating null character.
 | 
						|
  */
 | 
						|
 | 
						|
  buf_len= len && *len ? *len : strlen(string);
 | 
						|
  wchar_t *buf=  (wchar_t*)malloc((buf_len+1)*sizeof(wchar_t));
 | 
						|
 | 
						|
  if (!buf)
 | 
						|
  {
 | 
						|
    DBUG_PRINT("error",("Out of memory when converting utf8 string '%s'"
 | 
						|
                        " to wide-char representation", string));
 | 
						|
    return NULL;
 | 
						|
  }
 | 
						|
 | 
						|
  size_t  res;
 | 
						|
  res= MultiByteToWideChar(CP_UTF8,            // convert from UTF-8
 | 
						|
                           0,                  // conversion flags
 | 
						|
                           string,             // input buffer
 | 
						|
                           (int)buf_len,            // its size
 | 
						|
                           buf, (int)buf_len);      // output buffer and its size
 | 
						|
  if (res)
 | 
						|
  {
 | 
						|
    buf[res]= '\0';
 | 
						|
    if (len)
 | 
						|
      *len= res;
 | 
						|
    return buf;
 | 
						|
  }
 | 
						|
 | 
						|
  // error in MultiByteToWideChar()
 | 
						|
 | 
						|
#ifndef DBUG_OFF
 | 
						|
  Error_message_buf error_buf;
 | 
						|
  DBUG_PRINT("error", ("Could not convert UPN from UTF-8"
 | 
						|
                       ", MultiByteToWideChar() failed with error %X (%s)",
 | 
						|
                       GetLastError(), get_last_error_message(error_buf)));
 | 
						|
#endif
 | 
						|
 | 
						|
  // Let's check our assumption about sufficient buffer size
 | 
						|
  DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError());
 | 
						|
 | 
						|
  return NULL;
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/** Error handling ****************************************************/
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
  Returns error message corresponding to the last Windows error given
 | 
						|
  by GetLastError().
 | 
						|
 | 
						|
  @note Error message is overwritten by next call to
 | 
						|
  @c get_last_error_message().
 | 
						|
*/
 | 
						|
 | 
						|
const char* get_last_error_message(Error_message_buf buf)
 | 
						|
{
 | 
						|
  int error= GetLastError();
 | 
						|
 | 
						|
  buf[0]= '\0';
 | 
						|
  FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM,
 | 
						|
		NULL, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
 | 
						|
		(LPTSTR)buf, ERRMSG_BUFSIZE , NULL );
 | 
						|
 | 
						|
  return buf;
 | 
						|
}
 |