mirror of
https://github.com/MariaDB/server.git
synced 2025-01-24 15:54:37 +01:00
cleanup: one Item_func_vec_distance class, not three
prepare for MDEV-35450 VEC_DISTANCE auto-detection
This commit is contained in:
parent
d2ec5ec9c2
commit
528249a20a
5 changed files with 65 additions and 72 deletions
|
@ -6258,7 +6258,8 @@ class Create_func_vec_distance_euclidean: public Create_func_arg2
|
|||
{
|
||||
public:
|
||||
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
|
||||
{ return new (thd->mem_root) Item_func_vec_distance_euclidean(thd, arg1, arg2); }
|
||||
{ return new (thd->mem_root)
|
||||
Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::EUCLIDEAN); }
|
||||
|
||||
static Create_func_vec_distance_euclidean s_singleton;
|
||||
|
||||
|
@ -6274,7 +6275,8 @@ class Create_func_vec_distance_cosine: public Create_func_arg2
|
|||
{
|
||||
public:
|
||||
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
|
||||
{ return new (thd->mem_root) Item_func_vec_distance_cosine(thd, arg1, arg2); }
|
||||
{ return new (thd->mem_root)
|
||||
Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::COSINE); }
|
||||
|
||||
static Create_func_vec_distance_cosine s_singleton;
|
||||
|
||||
|
|
|
@ -24,7 +24,47 @@
|
|||
#include "vector_mhnsw.h"
|
||||
#include "sql_type_vector.h"
|
||||
|
||||
key_map Item_func_vec_distance_common::part_of_sortkey() const
|
||||
static double calc_distance_euclidean(float *v1, float *v2, size_t v_len)
|
||||
{
|
||||
double d= 0;
|
||||
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
||||
{
|
||||
float dist= get_float(v1) - get_float(v2);
|
||||
d+= dist * dist;
|
||||
}
|
||||
return sqrt(d);
|
||||
}
|
||||
|
||||
static double calc_distance_cosine(float *v1, float *v2, size_t v_len)
|
||||
{
|
||||
double dotp=0, abs1=0, abs2=0;
|
||||
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
||||
{
|
||||
float f1= get_float(v1), f2= get_float(v2);
|
||||
abs1+= f1 * f1;
|
||||
abs2+= f2 * f2;
|
||||
dotp+= f1 * f2;
|
||||
}
|
||||
return 1 - dotp/sqrt(abs1*abs2);
|
||||
}
|
||||
|
||||
Item_func_vec_distance::Item_func_vec_distance(THD *thd, Item *a, Item *b,
|
||||
distance_kind kind)
|
||||
:Item_real_func(thd, a, b), kind(kind)
|
||||
{
|
||||
}
|
||||
|
||||
bool Item_func_vec_distance::fix_length_and_dec(THD *thd)
|
||||
{
|
||||
switch (kind) {
|
||||
case EUCLIDEAN: calc_distance= calc_distance_euclidean; break;
|
||||
case COSINE: calc_distance= calc_distance_cosine; break;
|
||||
}
|
||||
set_maybe_null(); // if wrong dimensions
|
||||
return Item_real_func::fix_length_and_dec(thd);
|
||||
}
|
||||
|
||||
key_map Item_func_vec_distance::part_of_sortkey() const
|
||||
{
|
||||
key_map map(0);
|
||||
if (Item_field *item= get_field_arg())
|
||||
|
@ -33,13 +73,13 @@ key_map Item_func_vec_distance_common::part_of_sortkey() const
|
|||
KEY *keyinfo= f->table->s->key_info;
|
||||
for (uint i= f->table->s->keys; i < f->table->s->total_keys; i++)
|
||||
if (keyinfo[i].algorithm == HA_KEY_ALG_VECTOR && f->key_start.is_set(i)
|
||||
&& mhnsw_uses_distance(f->table, keyinfo + i, this))
|
||||
&& mhnsw_uses_distance(f->table, keyinfo + i) == kind)
|
||||
map.set_bit(i);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
double Item_func_vec_distance_common::val_real()
|
||||
double Item_func_vec_distance::val_real()
|
||||
{
|
||||
String *r1= args[0]->val_str();
|
||||
String *r2= args[1]->val_str();
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include "lex_string.h"
|
||||
#include "item_func.h"
|
||||
|
||||
class Item_func_vec_distance_common: public Item_real_func
|
||||
class Item_func_vec_distance: public Item_real_func
|
||||
{
|
||||
Item_field *get_field_arg() const
|
||||
{
|
||||
|
@ -36,16 +36,20 @@ class Item_func_vec_distance_common: public Item_real_func
|
|||
{
|
||||
return check_argument_types_or_binary(NULL, 0, arg_count);
|
||||
}
|
||||
virtual double calc_distance(float *v1, float *v2, size_t v_len) = 0;
|
||||
double (*calc_distance)(float *v1, float *v2, size_t v_len);
|
||||
|
||||
public:
|
||||
Item_func_vec_distance_common(THD *thd, Item *a, Item *b)
|
||||
:Item_real_func(thd, a, b) {}
|
||||
bool fix_length_and_dec(THD *thd) override
|
||||
enum distance_kind { EUCLIDEAN, COSINE } kind;
|
||||
Item_func_vec_distance(THD *thd, Item *a, Item *b, distance_kind kind);
|
||||
LEX_CSTRING func_name_cstring() const override
|
||||
{
|
||||
set_maybe_null(); // if wrong dimensions
|
||||
return Item_real_func::fix_length_and_dec(thd);
|
||||
static LEX_CSTRING name[3]= {
|
||||
{ STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") },
|
||||
{ STRING_WITH_LEN("VEC_DISTANCE_COSINE") }
|
||||
};
|
||||
return name[kind];
|
||||
}
|
||||
bool fix_length_and_dec(THD *thd) override;
|
||||
double val_real() override;
|
||||
Item *get_const_arg() const
|
||||
{
|
||||
|
@ -56,60 +60,8 @@ public:
|
|||
return NULL;
|
||||
}
|
||||
key_map part_of_sortkey() const override;
|
||||
};
|
||||
|
||||
|
||||
class Item_func_vec_distance_euclidean: public Item_func_vec_distance_common
|
||||
{
|
||||
double calc_distance(float *v1, float *v2, size_t v_len) override
|
||||
{
|
||||
double d= 0;
|
||||
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
||||
{
|
||||
float dist= get_float(v1) - get_float(v2);
|
||||
d+= dist * dist;
|
||||
}
|
||||
return sqrt(d);
|
||||
}
|
||||
|
||||
public:
|
||||
Item_func_vec_distance_euclidean(THD *thd, Item *a, Item *b)
|
||||
:Item_func_vec_distance_common(thd, a, b) {}
|
||||
LEX_CSTRING func_name_cstring() const override
|
||||
{
|
||||
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") };
|
||||
return name;
|
||||
}
|
||||
Item *do_get_copy(THD *thd) const override
|
||||
{ return get_item_copy<Item_func_vec_distance_euclidean>(thd, this); }
|
||||
};
|
||||
|
||||
|
||||
class Item_func_vec_distance_cosine: public Item_func_vec_distance_common
|
||||
{
|
||||
double calc_distance(float *v1, float *v2, size_t v_len) override
|
||||
{
|
||||
double dotp=0, abs1=0, abs2=0;
|
||||
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
||||
{
|
||||
float f1= get_float(v1), f2= get_float(v2);
|
||||
abs1+= f1 * f1;
|
||||
abs2+= f2 * f2;
|
||||
dotp+= f1 * f2;
|
||||
}
|
||||
return 1 - dotp/sqrt(abs1*abs2);
|
||||
}
|
||||
|
||||
public:
|
||||
Item_func_vec_distance_cosine(THD *thd, Item *a, Item *b)
|
||||
:Item_func_vec_distance_common(thd, a, b) {}
|
||||
LEX_CSTRING func_name_cstring() const override
|
||||
{
|
||||
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_COSINE") };
|
||||
return name;
|
||||
}
|
||||
Item *do_get_copy(THD *thd) const override
|
||||
{ return get_item_copy<Item_func_vec_distance_cosine>(thd, this); }
|
||||
{ return get_item_copy<Item_func_vec_distance>(thd, this); }
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "create_options.h"
|
||||
#include "table_cache.h"
|
||||
#include "vector_mhnsw.h"
|
||||
#include "item_vectorfunc.h"
|
||||
#include <scope.h>
|
||||
#include <my_atomic_wrapper.h>
|
||||
#include "bloom_filters.h"
|
||||
|
@ -1290,7 +1289,7 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
|
|||
{
|
||||
THD *thd= table->in_use;
|
||||
TABLE *graph= table->hlindex;
|
||||
auto *fun= static_cast<Item_func_vec_distance_common*>(dist->real_item());
|
||||
auto *fun= static_cast<Item_func_vec_distance*>(dist->real_item());
|
||||
DBUG_ASSERT(fun);
|
||||
|
||||
limit= std::min<ulonglong>(limit, max_ef);
|
||||
|
@ -1507,11 +1506,11 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
|
|||
return {s, len};
|
||||
}
|
||||
|
||||
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist)
|
||||
Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo)
|
||||
{
|
||||
if (keyinfo->option_struct->metric == EUCLIDEAN)
|
||||
return dynamic_cast<const Item_func_vec_distance_euclidean*>(dist) != NULL;
|
||||
return dynamic_cast<const Item_func_vec_distance_cosine*>(dist) != NULL;
|
||||
return Item_func_vec_distance::EUCLIDEAN;
|
||||
return Item_func_vec_distance::COSINE;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
*/
|
||||
|
||||
#include <my_global.h>
|
||||
#include "item.h"
|
||||
#include "item_vectorfunc.h"
|
||||
#include "m_string.h"
|
||||
#include "structs.h"
|
||||
#include "table.h"
|
||||
|
@ -33,7 +33,7 @@ int mhnsw_read_end(TABLE *table);
|
|||
int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo);
|
||||
int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate);
|
||||
void mhnsw_free(TABLE_SHARE *share);
|
||||
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist);
|
||||
Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo);
|
||||
|
||||
extern ha_create_table_option mhnsw_index_options[];
|
||||
extern st_plugin_int *mhnsw_plugin;
|
||||
|
|
Loading…
Add table
Reference in a new issue