// Mirror of https://github.com/MariaDB/server.git
// (synced 2025-01-24 15:54:37 +01:00; 272 lines, 7.4 KiB, C++)
/* Copyright (c) 2023, MariaDB

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; version 2 of the License.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA */
|
|
|
|
/**
  @file

  @brief
  This file defines all vector functions.
*/
|
|
|
|
#include "item_vectorfunc.h"
|
|
#include "vector_mhnsw.h"
|
|
#include "sql_type_vector.h"
|
|
|
|
static double calc_distance_euclidean(float *v1, float *v2, size_t v_len)
|
|
{
|
|
double d= 0;
|
|
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
|
{
|
|
float dist= get_float(v1) - get_float(v2);
|
|
d+= dist * dist;
|
|
}
|
|
return sqrt(d);
|
|
}
|
|
|
|
static double calc_distance_cosine(float *v1, float *v2, size_t v_len)
|
|
{
|
|
double dotp=0, abs1=0, abs2=0;
|
|
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
|
{
|
|
float f1= get_float(v1), f2= get_float(v2);
|
|
abs1+= f1 * f1;
|
|
abs2+= f2 * f2;
|
|
dotp+= f1 * f2;
|
|
}
|
|
return 1 - dotp/sqrt(abs1*abs2);
|
|
}
|
|
|
|
/**
  Create a vector distance item over two argument items.

  @param thd   passed through to the Item_real_func constructor
  @param a     first argument
  @param b     second argument
  @param kind  requested distance metric; stored as is and resolved to a
               concrete calculation function in fix_length_and_dec()
*/
Item_func_vec_distance::Item_func_vec_distance(THD *thd, Item *a, Item *b,
                                               distance_kind kind)
  : Item_real_func(thd, a, b),
    kind(kind)
{}
|
|
|
|
/**
  Pick the distance calculation function matching `kind`.

  For EUCLIDEAN and COSINE the function is selected directly.
  For AUTO both arguments are inspected, in order: when an argument
  resolves to a table field that starts a HA_KEY_ALG_VECTOR key, `kind` is
  replaced by what mhnsw_uses_distance() reports for that key and the
  method is re-run with the new kind. If no such key is found,
  ER_VEC_DISTANCE_TYPE is raised.

  @param thd  passed through to the recursive call and to the base class

  @return true on error, otherwise the result of
          Item_real_func::fix_length_and_dec()
*/
bool Item_func_vec_distance::fix_length_and_dec(THD *thd)
{
  switch (kind) {
  case EUCLIDEAN:
    calc_distance= calc_distance_euclidean;
    break;
  case COSINE:
    calc_distance= calc_distance_cosine;
    break;
  case AUTO:
    for (uint arg_no= 0; arg_no < 2; arg_no++)
    {
      auto *field_item= dynamic_cast<Item_field*>(args[arg_no]->real_item());
      if (!field_item)
        continue;

      TABLE_SHARE *share= field_item->field->orig_table->s;
      Field *share_field= share->field[field_item->field->field_index];
      KEY *keys= share->key_info;

      /* Only keys in the range [share->keys, share->total_keys) are
         examined. NOTE(review): presumably this is where vector indexes
         are stored - confirm against TABLE_SHARE. */
      for (uint key_no= share->keys; key_no < share->total_keys; key_no++)
      {
        if (keys[key_no].algorithm != HA_KEY_ALG_VECTOR ||
            !share_field->key_start.is_set(key_no))
          continue;

        /* Adopt the metric of the first matching key and redo the setup. */
        kind= mhnsw_uses_distance(share_field->table, keys + key_no);
        return fix_length_and_dec(thd);
      }
    }
    my_error(ER_VEC_DISTANCE_TYPE, MYF(0));
    return 1;
  }

  set_maybe_null(); // if wrong dimensions
  return Item_real_func::fix_length_and_dec(thd);
}
|
|
|
|
/**
  Build the set of keys whose distance metric matches this function.

  A key is included when it is a HA_KEY_ALG_VECTOR key, the field returned
  by get_field_arg() starts that key, and mhnsw_uses_distance() for the key
  equals this item's `kind`.

  @return bitmap of matching key numbers; empty when get_field_arg()
          returns no field
*/
key_map Item_func_vec_distance::part_of_sortkey() const
{
  key_map matching(0);

  Item_field *field_item= get_field_arg();
  if (!field_item)
    return matching;

  Field *fld= field_item->field;
  KEY *keys= fld->table->s->key_info;

  for (uint key_no= fld->table->s->keys;
       key_no < fld->table->s->total_keys;
       key_no++)
  {
    if (keys[key_no].algorithm != HA_KEY_ALG_VECTOR)
      continue;
    if (!fld->key_start.is_set(key_no))
      continue;
    if (mhnsw_uses_distance(fld->table, keys + key_no) == kind)
      matching.set_bit(key_no);
  }
  return matching;
}
|
|
|
|
double Item_func_vec_distance::val_real()
|
|
{
|
|
String *r1= args[0]->val_str();
|
|
String *r2= args[1]->val_str();
|
|
null_value= !r1 || !r2 || r1->length() != r2->length() ||
|
|
r1->length() % sizeof(float);
|
|
if (null_value)
|
|
return 0;
|
|
float *v1= (float *) r1->ptr();
|
|
float *v2= (float *) r2->ptr();
|
|
return calc_distance(v1, v2, (r1->length()) / sizeof(float));
|
|
}
|
|
|
|
/**
  Set result metadata for the vector-to-text conversion.

  Every 4 bytes of the argument produce one number of at most
  MAX_FLOAT_STR_LENGTH characters followed by a comma, and the whole list
  is enclosed in two brackets.

  The estimate is computed in 64 bits and clamped, because for a large
  argument max_length the product does not fit in 32 bits and would
  otherwise wrap around to a much too small result length.
  NOTE(review): the clamp assumes max_length is a 32-bit quantity, as it
  is elsewhere in this code base - confirm.

  @return false (never fails)
*/
bool Item_func_vec_totext::fix_length_and_dec(THD *thd)
{
  decimals= 0;

  unsigned long long text_len=
    (unsigned long long) (args[0]->max_length / 4) *
    (MAX_FLOAT_STR_LENGTH + 1 /* comma */) + 2 /* braces */;
  if (text_len > 0xFFFFFFFFULL)
    text_len= 0xFFFFFFFFULL;

  max_length= static_cast<unsigned int>(text_len);
  fix_length_and_charset(max_length, default_charset());
  set_maybe_null();
  return false;
}
|
|
|
|
/**
  Render a binary vector as text of the form "[v1,v2,...]".

  The argument is read as a sequence of 4-byte floats. Infinite values are
  printed as "Inf" / "-Inf", NaN as "NaN", everything else is formatted
  with my_gcvt().

  @param str  buffer that receives the text

  @return str, or nullptr with null_value set when the argument is NULL or
          its length is not a multiple of 4 (the latter also pushes an
          ER_VECTOR_BINARY_FORMAT_INVALID warning)
*/
String *Item_func_vec_totext::val_str_ascii(String *str)
{
  String *bin= args[0]->val_str();
  null_value= args[0]->null_value;
  if (null_value)
    return nullptr;

  /* A value that is not a whole number of floats yields NULL. */
  if (bin->length() % 4)
  {
    THD *thd= current_thd;
    push_warning_printf(thd, Sql_condition::WARN_LEVEL_WARN,
                        ER_VECTOR_BINARY_FORMAT_INVALID,
                        ER_THD(thd, ER_VECTOR_BINARY_FORMAT_INVALID));
    null_value= true;
    return nullptr;
  }

  str->length(0);
  str->set_charset(&my_charset_numeric);
  str->reserve(bin->length() / 4 * (MAX_FLOAT_STR_LENGTH + 1) + 2);

  str->append('[');
  const char *elem= bin->ptr();
  for (size_t offset= 0; offset < bin->length(); offset+= 4, elem+= 4)
  {
    /* Elements after the first one are preceded by a separator. */
    if (offset)
      str->append(',');

    float val= get_float(elem);
    if (std::isnan(val))
      str->append(STRING_WITH_LEN("NaN"));
    else if (!std::isinf(val))
    {
      char num[MAX_FLOAT_STR_LENGTH+1];
      size_t num_len= my_gcvt(val, MY_GCVT_ARG_FLOAT, MAX_FLOAT_STR_LENGTH,
                              num, 0);
      str->append(num, num_len);
    }
    else if (val < 0)
      str->append(STRING_WITH_LEN("-Inf"));
    else
      str->append(STRING_WITH_LEN("Inf"));
  }
  str->append(']');

  return str;
}
|
|
|
|
/**
  Create a vector-to-text conversion item.

  @param thd  passed through to the base class constructor
  @param a    the argument holding the binary vector
*/
Item_func_vec_totext::Item_func_vec_totext(THD *thd, Item *a)
  : Item_str_ascii_checksum_func(thd, a)
{}
|
|
|
|
/**
  Create a text-to-vector conversion item.

  @param thd  passed through to the base class constructor
  @param a    the argument holding the textual vector
*/
Item_func_vec_fromtext::Item_func_vec_fromtext(THD *thd, Item *a)
  : Item_str_func(thd, a)
{}
|
|
|
|
/**
  Set result metadata for the text-to-vector conversion.

  Worst case scenario, for a valid input we have a string of the form:
  [1,2,3,4,5,...] single digit numbers.
  This means we can have (max_length - 1) / 2 floats.
  Each float takes 4 bytes, so we do (max_length - 1) * 2.

  The bound is computed in 64 bits: with unsigned 32-bit arithmetic an
  argument max_length of 0 would underflow to a huge value, and lengths
  above 2^31 would wrap around to a too small one. The result is clamped
  to what fits in 32 bits.
  NOTE(review): the clamp assumes max_length is a 32-bit quantity, as it
  is elsewhere in this code base - confirm.

  @return false (never fails)
*/
bool Item_func_vec_fromtext::fix_length_and_dec(THD *thd)
{
  decimals= 0;

  unsigned long long bin_len= args[0]->max_length;
  bin_len= bin_len ? (bin_len - 1) * 2 : 0;
  if (bin_len > 0xFFFFFFFFULL)
    bin_len= 0xFFFFFFFFULL;

  fix_length_and_charset(static_cast<unsigned int>(bin_len), &my_charset_bin);
  set_maybe_null();
  return false;
}
|
|
|
|
/**
  Parse a JSON array of numbers into a binary vector.

  The argument is scanned with the JSON engine. Only a top-level array
  whose elements are all numbers is accepted; each number is converted
  with the argument charset's strntod(), narrowed to float and appended to
  buf with float4store().

  @param buf  buffer that receives the binary vector

  @return buf on success; nullptr with null_value set when
          - the argument is NULL,
          - the JSON engine reports an error (reported through
            report_json_error_ex() as a warning),
          - the text is not an array of numbers or a number fails to
            convert (ER_VECTOR_FORMAT_INVALID warning with the position),
          - Type_handler_vector::is_valid() rejects the produced bytes
            (ER_TRUNCATED_WRONG_VALUE warning).
*/
String *Item_func_vec_fromtext::val_str(String *buf)
{
  json_engine_t je;
  bool saw_array_end= false;
  String *value= args[0]->val_json(&tmp_js);

  null_value= !value;
  if (null_value)
    return nullptr;

  buf->length(0);
  CHARSET_INFO *cs= value->charset();
  const uchar *text_begin= reinterpret_cast<const uchar *>(value->ptr());
  const uchar *text_end= text_begin + value->length();

  if (json_scan_start(&je, cs, text_begin, text_end) ||
      json_read_value(&je))
    goto error;

  if (je.value_type != JSON_VALUE_ARRAY)
    goto error_format;

  /* Accept only arrays of floats. */
  do
  {
    if (je.state == JST_ARRAY_START)
      continue;

    if (je.state == JST_ARRAY_END)
    {
      saw_array_end= true;
      continue;
    }

    if (je.state != JST_VALUE)
      goto error_format;

    if (json_read_value(&je))
      goto error;

    if (je.value_type != JSON_VALUE_NUMBER)
      goto error_format;

    {
      int conv_error;
      char *num_begin= (char *) je.value_begin, *num_end;
      float f= (float) cs->strntod(num_begin, je.value_len, &num_end,
                                   &conv_error);
      if (unlikely(conv_error))
        goto error_format;

      char f_bin[4];
      float4store(f_bin, f);
      buf->append(f_bin, sizeof(f_bin));
    }
  } while (json_scan_next(&je) == 0);

  /* The scan must have seen the closing bracket of the array. */
  if (!saw_array_end)
    goto error_format;

  if (!Type_handler_vector::is_valid(buf->ptr(), buf->length()))
  {
    null_value= true;
    push_warning_printf(current_thd, Sql_condition::WARN_LEVEL_WARN,
                        ER_TRUNCATED_WRONG_VALUE, ER(ER_TRUNCATED_WRONG_VALUE),
                        "vector", value->c_ptr_safe());
    return nullptr;
  }
  return buf;

error_format:
  {
    /* Offset of the scanner's current position within the input text. */
    int position= (int)((const char *) je.s.c_str - value->ptr());
    null_value= true;
    push_warning_printf(current_thd, Sql_condition::WARN_LEVEL_WARN,
                        ER_VECTOR_FORMAT_INVALID, ER(ER_VECTOR_FORMAT_INVALID),
                        position, value->c_ptr_safe());
    return nullptr;
  }

error:
  report_json_error_ex(value->ptr(), &je, func_name(),
                       0, Sql_condition::WARN_LEVEL_WARN);
  null_value= true;
  return nullptr;
}
|