cleanup: one Item_func_vec_distance class, not three

prepare for MDEV-35450 VEC_DISTANCE auto-detection
This commit is contained in:
Sergei Golubchik 2024-12-08 17:14:42 +01:00
parent d2ec5ec9c2
commit 528249a20a
5 changed files with 65 additions and 72 deletions

View file

@ -6258,7 +6258,8 @@ class Create_func_vec_distance_euclidean: public Create_func_arg2
{
public:
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
{ return new (thd->mem_root) Item_func_vec_distance_euclidean(thd, arg1, arg2); }
{ return new (thd->mem_root)
Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::EUCLIDEAN); }
static Create_func_vec_distance_euclidean s_singleton;
@ -6274,7 +6275,8 @@ class Create_func_vec_distance_cosine: public Create_func_arg2
{
public:
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
{ return new (thd->mem_root) Item_func_vec_distance_cosine(thd, arg1, arg2); }
{ return new (thd->mem_root)
Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::COSINE); }
static Create_func_vec_distance_cosine s_singleton;

View file

@ -24,7 +24,47 @@
#include "vector_mhnsw.h"
#include "sql_type_vector.h"
key_map Item_func_vec_distance_common::part_of_sortkey() const
static double calc_distance_euclidean(float *v1, float *v2, size_t v_len)
{
double d= 0;
for (size_t i= 0; i < v_len; i++, v1++, v2++)
{
float dist= get_float(v1) - get_float(v2);
d+= dist * dist;
}
return sqrt(d);
}
static double calc_distance_cosine(float *v1, float *v2, size_t v_len)
{
double dotp=0, abs1=0, abs2=0;
for (size_t i= 0; i < v_len; i++, v1++, v2++)
{
float f1= get_float(v1), f2= get_float(v2);
abs1+= f1 * f1;
abs2+= f2 * f2;
dotp+= f1 * f2;
}
return 1 - dotp/sqrt(abs1*abs2);
}
Item_func_vec_distance::Item_func_vec_distance(THD *thd, Item *a, Item *b,
distance_kind kind)
:Item_real_func(thd, a, b), kind(kind)
{
}
bool Item_func_vec_distance::fix_length_and_dec(THD *thd)
{
switch (kind) {
case EUCLIDEAN: calc_distance= calc_distance_euclidean; break;
case COSINE: calc_distance= calc_distance_cosine; break;
}
set_maybe_null(); // if wrong dimensions
return Item_real_func::fix_length_and_dec(thd);
}
key_map Item_func_vec_distance::part_of_sortkey() const
{
key_map map(0);
if (Item_field *item= get_field_arg())
@ -33,13 +73,13 @@ key_map Item_func_vec_distance_common::part_of_sortkey() const
KEY *keyinfo= f->table->s->key_info;
for (uint i= f->table->s->keys; i < f->table->s->total_keys; i++)
if (keyinfo[i].algorithm == HA_KEY_ALG_VECTOR && f->key_start.is_set(i)
&& mhnsw_uses_distance(f->table, keyinfo + i, this))
&& mhnsw_uses_distance(f->table, keyinfo + i) == kind)
map.set_bit(i);
}
return map;
}
double Item_func_vec_distance_common::val_real()
double Item_func_vec_distance::val_real()
{
String *r1= args[0]->val_str();
String *r2= args[1]->val_str();

View file

@ -22,7 +22,7 @@
#include "lex_string.h"
#include "item_func.h"
class Item_func_vec_distance_common: public Item_real_func
class Item_func_vec_distance: public Item_real_func
{
Item_field *get_field_arg() const
{
@ -36,16 +36,20 @@ class Item_func_vec_distance_common: public Item_real_func
{
return check_argument_types_or_binary(NULL, 0, arg_count);
}
virtual double calc_distance(float *v1, float *v2, size_t v_len) = 0;
double (*calc_distance)(float *v1, float *v2, size_t v_len);
public:
Item_func_vec_distance_common(THD *thd, Item *a, Item *b)
:Item_real_func(thd, a, b) {}
bool fix_length_and_dec(THD *thd) override
enum distance_kind { EUCLIDEAN, COSINE } kind;
Item_func_vec_distance(THD *thd, Item *a, Item *b, distance_kind kind);
LEX_CSTRING func_name_cstring() const override
{
set_maybe_null(); // if wrong dimensions
return Item_real_func::fix_length_and_dec(thd);
static LEX_CSTRING name[3]= {
{ STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") },
{ STRING_WITH_LEN("VEC_DISTANCE_COSINE") }
};
return name[kind];
}
bool fix_length_and_dec(THD *thd) override;
double val_real() override;
Item *get_const_arg() const
{
@ -56,60 +60,8 @@ public:
return NULL;
}
key_map part_of_sortkey() const override;
};
class Item_func_vec_distance_euclidean: public Item_func_vec_distance_common
{
double calc_distance(float *v1, float *v2, size_t v_len) override
{
double d= 0;
for (size_t i= 0; i < v_len; i++, v1++, v2++)
{
float dist= get_float(v1) - get_float(v2);
d+= dist * dist;
}
return sqrt(d);
}
public:
Item_func_vec_distance_euclidean(THD *thd, Item *a, Item *b)
:Item_func_vec_distance_common(thd, a, b) {}
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") };
return name;
}
Item *do_get_copy(THD *thd) const override
{ return get_item_copy<Item_func_vec_distance_euclidean>(thd, this); }
};
class Item_func_vec_distance_cosine: public Item_func_vec_distance_common
{
double calc_distance(float *v1, float *v2, size_t v_len) override
{
double dotp=0, abs1=0, abs2=0;
for (size_t i= 0; i < v_len; i++, v1++, v2++)
{
float f1= get_float(v1), f2= get_float(v2);
abs1+= f1 * f1;
abs2+= f2 * f2;
dotp+= f1 * f2;
}
return 1 - dotp/sqrt(abs1*abs2);
}
public:
Item_func_vec_distance_cosine(THD *thd, Item *a, Item *b)
:Item_func_vec_distance_common(thd, a, b) {}
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_COSINE") };
return name;
}
Item *do_get_copy(THD *thd) const override
{ return get_item_copy<Item_func_vec_distance_cosine>(thd, this); }
{ return get_item_copy<Item_func_vec_distance>(thd, this); }
};

View file

@ -20,7 +20,6 @@
#include "create_options.h"
#include "table_cache.h"
#include "vector_mhnsw.h"
#include "item_vectorfunc.h"
#include <scope.h>
#include <my_atomic_wrapper.h>
#include "bloom_filters.h"
@ -1290,7 +1289,7 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
{
THD *thd= table->in_use;
TABLE *graph= table->hlindex;
auto *fun= static_cast<Item_func_vec_distance_common*>(dist->real_item());
auto *fun= static_cast<Item_func_vec_distance*>(dist->real_item());
DBUG_ASSERT(fun);
limit= std::min<ulonglong>(limit, max_ef);
@ -1507,11 +1506,11 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
return {s, len};
}
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist)
Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo)
{
if (keyinfo->option_struct->metric == EUCLIDEAN)
return dynamic_cast<const Item_func_vec_distance_euclidean*>(dist) != NULL;
return dynamic_cast<const Item_func_vec_distance_cosine*>(dist) != NULL;
return Item_func_vec_distance::EUCLIDEAN;
return Item_func_vec_distance::COSINE;
}
/*

View file

@ -16,7 +16,7 @@
*/
#include <my_global.h>
#include "item.h"
#include "item_vectorfunc.h"
#include "m_string.h"
#include "structs.h"
#include "table.h"
@ -33,7 +33,7 @@ int mhnsw_read_end(TABLE *table);
int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo);
int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate);
void mhnsw_free(TABLE_SHARE *share);
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist);
Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo);
extern ha_create_table_option mhnsw_index_options[];
extern st_plugin_int *mhnsw_plugin;