mirror of
https://github.com/MariaDB/server.git
synced 2025-01-16 03:52:35 +01:00
mhnsw: refactor FVector* classes
Now there's an FVector class which is a pure vector, an array of floats. It doesn't necessarily corresponds to a row in the table, and usually there is only one FVector instance - the one we're searching for. And there's an FVectorNode class, which is a node in the graph. It has a ref (identifying a row in the source table), possibly an array of floats (or not — in which case it will be read lazily from the source table as needed). There are many FVectorNodes and they're cached to avoid re-reading them from the disk.
This commit is contained in:
parent
10de659020
commit
267092d4a1
1 changed files with 151 additions and 175 deletions
|
@ -42,64 +42,31 @@ const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\
|
|||
")};
|
||||
|
||||
|
||||
class FVectorRef: public Sql_alloc
|
||||
class MHNSW_Context;
|
||||
|
||||
class FVector: public Sql_alloc
|
||||
{
|
||||
public:
|
||||
// Shallow ref copy. Used for other ref lookups in HashSet
|
||||
FVectorRef(const void *ref, size_t ref_len): ref{(uchar*)ref}, ref_len{ref_len} {}
|
||||
|
||||
static uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool)
|
||||
{
|
||||
*key_len= elem->ref_len;
|
||||
return elem->ref;
|
||||
}
|
||||
|
||||
static void free_vector(void *elem)
|
||||
{
|
||||
delete (FVectorRef *)elem;
|
||||
}
|
||||
|
||||
size_t get_ref_len() const { return ref_len; }
|
||||
uchar* get_ref() const { return ref; }
|
||||
|
||||
MHNSW_Context *ctx;
|
||||
FVector(MHNSW_Context *ctx_, const void *vec_);
|
||||
float *vec;
|
||||
protected:
|
||||
FVectorRef() = default;
|
||||
uchar *ref;
|
||||
size_t ref_len;
|
||||
FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {}
|
||||
};
|
||||
|
||||
class FVector: public FVectorRef
|
||||
class FVectorNode: public FVector
|
||||
{
|
||||
private:
|
||||
float *vec;
|
||||
size_t vec_len;
|
||||
uchar *ref;
|
||||
public:
|
||||
FVector(): vec(nullptr), vec_len(0) {}
|
||||
FVectorNode(MHNSW_Context *ctx_, const void *ref_);
|
||||
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
|
||||
float distance_to(const FVector &other) const;
|
||||
int instantiate_vector();
|
||||
size_t get_ref_len() const;
|
||||
uchar *get_ref() const { return ref; }
|
||||
|
||||
bool init(MEM_ROOT *root, const uchar *ref_, size_t ref_len_, const void *vec_, size_t bytes)
|
||||
{
|
||||
ref= (uchar*)alloc_root(root, ref_len_ + bytes);
|
||||
if (!ref)
|
||||
return true;
|
||||
|
||||
vec= reinterpret_cast<float *>(ref + ref_len_);
|
||||
|
||||
memcpy(ref, ref_, ref_len_);
|
||||
memcpy(vec, vec_, bytes);
|
||||
|
||||
ref_len= ref_len_;
|
||||
vec_len= bytes / sizeof(float);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t size_of() const { return vec_len * sizeof(float); }
|
||||
|
||||
float distance_to(const FVector &other) const
|
||||
{
|
||||
DBUG_ASSERT(other.vec_len == vec_len);
|
||||
return euclidean_vec_distance(vec, other.vec, vec_len);
|
||||
}
|
||||
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
|
||||
};
|
||||
|
||||
class MHNSW_Context
|
||||
|
@ -108,8 +75,9 @@ class MHNSW_Context
|
|||
MEM_ROOT root;
|
||||
TABLE *table;
|
||||
Field *vec_field;
|
||||
Hash_set<FVectorRef> vector_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key};
|
||||
Hash_set<FVectorRef> vector_ref_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key};
|
||||
size_t vec_len= 0;
|
||||
|
||||
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
|
||||
|
||||
MHNSW_Context(TABLE *table, Field *vec_field)
|
||||
: table(table), vec_field(vec_field)
|
||||
|
@ -122,40 +90,67 @@ class MHNSW_Context
|
|||
free_root(&root, MYF(0));
|
||||
}
|
||||
|
||||
FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len)
|
||||
{
|
||||
FVectorRef tmp(ref, ref_len);
|
||||
FVectorRef *v= vector_ref_cache.find(&tmp);
|
||||
if (v)
|
||||
return v;
|
||||
|
||||
uchar *buf= (uchar*)memdup_root(&root, ref, ref_len);
|
||||
if ((v= new (&root) FVectorRef(buf, ref_len)))
|
||||
vector_ref_cache.insert(v);
|
||||
return v;
|
||||
}
|
||||
|
||||
FVector *get_fvector_from_source(const FVectorRef &ref)
|
||||
{
|
||||
FVectorRef *v= vector_cache.find(&ref);
|
||||
if (v)
|
||||
return (FVector *)v;
|
||||
|
||||
if (table->file->ha_rnd_pos(table->record[0], ref.get_ref()))
|
||||
return nullptr; // XXX the error code is lost
|
||||
|
||||
String buf, *vec= vec_field->val_str(&buf);
|
||||
|
||||
FVector *new_vector= new (&root) FVector;
|
||||
new_vector->init(&root, ref.get_ref(), ref.get_ref_len(), vec->ptr(), vec->length());
|
||||
|
||||
vector_cache.insert(new_vector);
|
||||
|
||||
return new_vector;
|
||||
}
|
||||
FVectorNode *get_node(const void *ref_);
|
||||
};
|
||||
|
||||
static int cmp_vec(const FVector *target, const FVector *a, const FVector *b)
|
||||
FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
|
||||
{
|
||||
vec= (float*)memdup_root(&ctx->root, vec_, ctx->vec_len * sizeof(float));
|
||||
}
|
||||
|
||||
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_)
|
||||
: FVector(ctx_)
|
||||
{
|
||||
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
|
||||
}
|
||||
|
||||
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_)
|
||||
: FVector(ctx_, vec_)
|
||||
{
|
||||
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
|
||||
}
|
||||
|
||||
float FVectorNode::distance_to(const FVector &other) const
|
||||
{
|
||||
if (!vec)
|
||||
const_cast<FVectorNode*>(this)->instantiate_vector();
|
||||
return euclidean_vec_distance(vec, other.vec, ctx->vec_len);
|
||||
}
|
||||
|
||||
int FVectorNode::instantiate_vector()
|
||||
{
|
||||
DBUG_ASSERT(vec == nullptr);
|
||||
if (int err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref))
|
||||
return err;
|
||||
String buf, *v= ctx->vec_field->val_str(&buf);
|
||||
ctx->vec_len= v->length() / sizeof(float);
|
||||
vec= (float*)memdup_root(&ctx->root, v->ptr(), v->length());
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t FVectorNode::get_ref_len() const
|
||||
{
|
||||
return ctx->table->file->ref_length;
|
||||
}
|
||||
|
||||
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
|
||||
{
|
||||
*key_len= elem->get_ref_len();
|
||||
return elem->ref;
|
||||
}
|
||||
|
||||
FVectorNode *MHNSW_Context::get_node(const void *ref)
|
||||
{
|
||||
FVectorNode *node= node_cache.find(ref, table->file->ref_length);
|
||||
if (!node)
|
||||
{
|
||||
node= new (&root) FVectorNode(this, ref);
|
||||
node_cache.insert(node);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNode *b)
|
||||
{
|
||||
float a_dist= a->distance_to(*target);
|
||||
float b_dist= b->distance_to(*target);
|
||||
|
@ -171,8 +166,8 @@ const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
|
|||
const bool EXTEND_CANDIDATES=true; // XXX or false?
|
||||
|
||||
static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
||||
const FVectorRef &source_node,
|
||||
List<FVectorRef> *neighbors)
|
||||
const FVectorNode &source_node,
|
||||
List<FVectorNode> *neighbors)
|
||||
{
|
||||
TABLE *graph= ctx->table->hlindex;
|
||||
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
|
||||
|
@ -189,18 +184,16 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
|||
// mhnsw_insert() guarantees that all ref have the same length
|
||||
uint ref_length= source_node.get_ref_len();
|
||||
|
||||
const uchar *neigh_arr_bytes= reinterpret_cast<const uchar *>(str->ptr());
|
||||
const char *neigh_arr_bytes= str->ptr();
|
||||
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
|
||||
if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length())
|
||||
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
|
||||
|
||||
const uchar *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
|
||||
const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
|
||||
for (uint i= 0; i < number_of_neighbors; i++)
|
||||
{
|
||||
FVectorRef *v= ctx->get_fvector_ref(pos, ref_length);
|
||||
if (!v)
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
neighbors->push_back(v, &ctx->root);
|
||||
FVectorNode *neigh= ctx->get_node(pos);
|
||||
neighbors->push_back(neigh, &ctx->root);
|
||||
pos+= ref_length;
|
||||
}
|
||||
|
||||
|
@ -210,20 +203,20 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
|||
|
||||
static int select_neighbors(MHNSW_Context *ctx,
|
||||
size_t layer_number, const FVector &target,
|
||||
const List<FVectorRef> &candidates,
|
||||
const List<FVectorNode> &candidates,
|
||||
size_t max_neighbor_connections,
|
||||
List<FVectorRef> *neighbors)
|
||||
List<FVectorNode> *neighbors)
|
||||
{
|
||||
/*
|
||||
TODO: If the input neighbors list is already sorted in search_layer, then
|
||||
no need to do additional queue build steps here.
|
||||
*/
|
||||
|
||||
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key);
|
||||
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
|
||||
|
||||
Queue<FVector, const FVector> pq; // working queue
|
||||
Queue<FVector, const FVector> pq_discard; // queue for discarded candidates
|
||||
Queue<FVector, const FVector> best; // neighbors to return
|
||||
Queue<FVectorNode, const FVector> pq; // working queue
|
||||
Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates
|
||||
Queue<FVectorNode, const FVector> best; // neighbors to return
|
||||
|
||||
// TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size.
|
||||
// This should not be fixed.
|
||||
|
@ -232,32 +225,26 @@ static int select_neighbors(MHNSW_Context *ctx,
|
|||
best.init(max_neighbor_connections, true, cmp_vec, &target))
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
|
||||
for (const FVectorRef &candidate : candidates)
|
||||
for (const FVectorNode &candidate : candidates)
|
||||
{
|
||||
FVector *v= ctx->get_fvector_from_source(candidate);
|
||||
if (!v)
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
visited.insert(&candidate);
|
||||
pq.push(v);
|
||||
pq.push(&candidate);
|
||||
}
|
||||
|
||||
if (EXTEND_CANDIDATES)
|
||||
{
|
||||
for (const FVectorRef &candidate : candidates)
|
||||
for (const FVectorNode &candidate : candidates)
|
||||
{
|
||||
List<FVectorRef> candidate_neighbors;
|
||||
List<FVectorNode> candidate_neighbors;
|
||||
if (int err= get_neighbors(ctx, layer_number, candidate,
|
||||
&candidate_neighbors))
|
||||
return err;
|
||||
for (const FVectorRef &extra_candidate : candidate_neighbors)
|
||||
for (const FVectorNode &extra_candidate : candidate_neighbors)
|
||||
{
|
||||
if (visited.find(&extra_candidate))
|
||||
continue;
|
||||
visited.insert(&extra_candidate);
|
||||
FVector *v= ctx->get_fvector_from_source(extra_candidate);
|
||||
if (!v)
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
pq.push(v);
|
||||
pq.push(&extra_candidate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -268,7 +255,7 @@ static int select_neighbors(MHNSW_Context *ctx,
|
|||
float best_top= best.top()->distance_to(target);
|
||||
while (pq.elements() && best.elements() < max_neighbor_connections)
|
||||
{
|
||||
const FVector *vec= pq.pop();
|
||||
const FVectorNode *vec= pq.pop();
|
||||
const float cur_dist= vec->distance_to(target);
|
||||
if (cur_dist < best_top)
|
||||
{
|
||||
|
@ -298,7 +285,7 @@ static int select_neighbors(MHNSW_Context *ctx,
|
|||
|
||||
|
||||
static void dbug_print_vec_ref(const char *prefix, uint layer,
|
||||
const FVectorRef &ref)
|
||||
const FVectorNode &ref)
|
||||
{
|
||||
#ifndef DBUG_OFF
|
||||
// TODO(cvicentiu) disable this in release build.
|
||||
|
@ -313,11 +300,11 @@ static void dbug_print_vec_ref(const char *prefix, uint layer,
|
|||
#endif
|
||||
}
|
||||
|
||||
static void dbug_print_vec_neigh(uint layer, const List<FVectorRef> &neighbors)
|
||||
static void dbug_print_vec_neigh(uint layer, const List<FVectorNode> &neighbors)
|
||||
{
|
||||
#ifndef DBUG_OFF
|
||||
DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements));
|
||||
for (const FVectorRef& ref : neighbors)
|
||||
for (const FVectorNode& ref : neighbors)
|
||||
{
|
||||
dbug_print_vec_ref("NEIGH: ", layer, ref);
|
||||
}
|
||||
|
@ -325,8 +312,8 @@ static void dbug_print_vec_neigh(uint layer, const List<FVectorRef> &neighbors)
|
|||
}
|
||||
|
||||
static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
||||
const FVectorRef &source_node,
|
||||
const List<FVectorRef> &new_neighbors)
|
||||
const FVectorNode &source_node,
|
||||
const List<FVectorNode> &new_neighbors)
|
||||
{
|
||||
TABLE *graph= ctx->table->hlindex;
|
||||
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
|
||||
|
@ -378,14 +365,14 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
|||
static int update_second_degree_neighbors(MHNSW_Context *ctx,
|
||||
size_t layer_number,
|
||||
uint max_neighbors,
|
||||
const FVectorRef &source_node,
|
||||
const List<FVectorRef> &neighbors)
|
||||
const FVectorNode &source_node,
|
||||
const List<FVectorNode> &neighbors)
|
||||
{
|
||||
//dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node);
|
||||
//dbug_print_vec_neigh(layer_number, neighbors);
|
||||
for (const FVectorRef &neigh: neighbors) // XXX why this loop?
|
||||
for (const FVectorNode &neigh: neighbors) // XXX why this loop?
|
||||
{
|
||||
List<FVectorRef> new_neighbors;
|
||||
List<FVectorNode> new_neighbors;
|
||||
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
|
||||
return err;
|
||||
new_neighbors.push_back(&source_node, &ctx->root);
|
||||
|
@ -393,20 +380,17 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
|
|||
return err;
|
||||
}
|
||||
|
||||
for (const FVectorRef &neigh: neighbors)
|
||||
for (const FVectorNode &neigh: neighbors)
|
||||
{
|
||||
List<FVectorRef> new_neighbors;
|
||||
List<FVectorNode> new_neighbors;
|
||||
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
|
||||
return err;
|
||||
|
||||
if (new_neighbors.elements > max_neighbors)
|
||||
{
|
||||
// shrink the neighbors
|
||||
List<FVectorRef> selected;
|
||||
FVector *v= ctx->get_fvector_from_source(neigh);
|
||||
if (!v)
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
if (int err= select_neighbors(ctx, layer_number, *v,
|
||||
List<FVectorNode> selected;
|
||||
if (int err= select_neighbors(ctx, layer_number, neigh,
|
||||
new_neighbors, max_neighbors, &selected))
|
||||
return err;
|
||||
if (int err= write_neighbors(ctx, layer_number, neigh, selected))
|
||||
|
@ -420,8 +404,8 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
|
|||
|
||||
static int update_neighbors(MHNSW_Context *ctx,
|
||||
size_t layer_number, uint max_neighbors,
|
||||
const FVectorRef &source_node,
|
||||
const List<FVectorRef> &neighbors)
|
||||
const FVectorNode &source_node,
|
||||
const List<FVectorNode> &neighbors)
|
||||
{
|
||||
// 1. update node's neighbors
|
||||
if (int err= write_neighbors(ctx, layer_number, source_node, neighbors))
|
||||
|
@ -433,36 +417,35 @@ static int update_neighbors(MHNSW_Context *ctx,
|
|||
|
||||
|
||||
static int search_layer(MHNSW_Context *ctx, const FVector &target,
|
||||
const List<FVectorRef> &start_nodes,
|
||||
const List<FVectorNode> &start_nodes,
|
||||
uint max_candidates_return, size_t layer,
|
||||
List<FVectorRef> *result)
|
||||
List<FVectorNode> *result)
|
||||
{
|
||||
DBUG_ASSERT(start_nodes.elements > 0);
|
||||
DBUG_ASSERT(result->elements == 0);
|
||||
|
||||
Queue<FVector, const FVector> candidates;
|
||||
Queue<FVector, const FVector> best;
|
||||
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key);
|
||||
Queue<FVectorNode, const FVector> candidates;
|
||||
Queue<FVectorNode, const FVector> best;
|
||||
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
|
||||
|
||||
candidates.init(10000, false, cmp_vec, &target);
|
||||
best.init(max_candidates_return, true, cmp_vec, &target);
|
||||
|
||||
for (const FVectorRef &node : start_nodes)
|
||||
for (const FVectorNode &node : start_nodes)
|
||||
{
|
||||
FVector *v= ctx->get_fvector_from_source(node);
|
||||
candidates.push(v);
|
||||
candidates.push(&node);
|
||||
if (best.elements() < max_candidates_return)
|
||||
best.push(v);
|
||||
else if (v->distance_to(target) > best.top()->distance_to(target))
|
||||
best.replace_top(v);
|
||||
visited.insert(v);
|
||||
best.push(&node);
|
||||
else if (node.distance_to(target) > best.top()->distance_to(target))
|
||||
best.replace_top(&node);
|
||||
visited.insert(&node);
|
||||
dbug_print_vec_ref("INSERTING node in visited: ", layer, node);
|
||||
}
|
||||
|
||||
float furthest_best= best.top()->distance_to(target);
|
||||
while (candidates.elements())
|
||||
{
|
||||
const FVector &cur_vec= *candidates.pop();
|
||||
const FVectorNode &cur_vec= *candidates.pop();
|
||||
float cur_distance= cur_vec.distance_to(target);
|
||||
if (cur_distance > furthest_best && best.elements() == max_candidates_return)
|
||||
{
|
||||
|
@ -470,26 +453,25 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
|
|||
// Can't get better.
|
||||
}
|
||||
|
||||
List<FVectorRef> neighbors;
|
||||
List<FVectorNode> neighbors;
|
||||
get_neighbors(ctx, layer, cur_vec, &neighbors);
|
||||
|
||||
for (const FVectorRef &neigh: neighbors)
|
||||
for (const FVectorNode &neigh: neighbors)
|
||||
{
|
||||
if (visited.find(&neigh))
|
||||
continue;
|
||||
|
||||
FVector *clone= ctx->get_fvector_from_source(neigh);
|
||||
visited.insert(clone);
|
||||
visited.insert(&neigh);
|
||||
if (best.elements() < max_candidates_return)
|
||||
{
|
||||
candidates.push(clone);
|
||||
best.push(clone);
|
||||
candidates.push(&neigh);
|
||||
best.push(&neigh);
|
||||
furthest_best= best.top()->distance_to(target);
|
||||
}
|
||||
else if (clone->distance_to(target) < furthest_best)
|
||||
else if (neigh.distance_to(target) < furthest_best)
|
||||
{
|
||||
best.replace_top(clone);
|
||||
candidates.push(clone);
|
||||
best.replace_top(&neigh);
|
||||
candidates.push(&neigh);
|
||||
furthest_best= best.top()->distance_to(target);
|
||||
}
|
||||
}
|
||||
|
@ -562,34 +544,32 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
|
|||
|
||||
// First insert!
|
||||
h->position(table->record[0]);
|
||||
return write_neighbors(&ctx, 0, {h->ref, h->ref_length}, {});
|
||||
return write_neighbors(&ctx, 0, {&ctx, h->ref}, {});
|
||||
}
|
||||
|
||||
longlong max_layer= graph->field[0]->val_int();
|
||||
|
||||
h->position(table->record[0]);
|
||||
|
||||
List<FVectorRef> candidates;
|
||||
List<FVectorRef> start_nodes;
|
||||
List<FVectorNode> candidates;
|
||||
List<FVectorNode> start_nodes;
|
||||
String ref_str, *ref_ptr;
|
||||
|
||||
ref_ptr= graph->field[1]->val_str(&ref_str);
|
||||
FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()};
|
||||
FVectorNode start_node(&ctx, ref_ptr->ptr());
|
||||
|
||||
// TODO(cvicentiu) use a random start node in last layer.
|
||||
// XXX or may be *all* nodes in the last layer? there should be few
|
||||
if (start_nodes.push_back(&start_node_ref, &ctx.root))
|
||||
if (start_nodes.push_back(&start_node, &ctx.root))
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
|
||||
FVector *v= ctx.get_fvector_from_source(start_node_ref);
|
||||
if (!v)
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
if (int err= start_node.instantiate_vector())
|
||||
return err;
|
||||
|
||||
if (v->size_of() != res->length())
|
||||
if (ctx.vec_len * sizeof(float) != res->length())
|
||||
return bad_value_on_insert(vec_field);
|
||||
|
||||
FVector target;
|
||||
target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length());
|
||||
FVectorNode target(&ctx, h->ref, res->ptr());
|
||||
|
||||
double new_num= my_rnd(&thd->rand);
|
||||
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
|
||||
|
@ -609,7 +589,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
|
|||
for (longlong cur_layer= std::min(max_layer, new_node_layer);
|
||||
cur_layer >= 0; cur_layer--)
|
||||
{
|
||||
List<FVectorRef> neighbors;
|
||||
List<FVectorNode> neighbors;
|
||||
if (int err= search_layer(&ctx, target, start_nodes,
|
||||
thd->variables.hnsw_ef_constructor, cur_layer,
|
||||
&candidates))
|
||||
|
@ -666,33 +646,29 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
|
|||
|
||||
longlong max_layer= graph->field[0]->val_int();
|
||||
|
||||
List<FVectorRef> candidates; // XXX List? not Queue by distance?
|
||||
List<FVectorRef> start_nodes;
|
||||
String ref_str, *ref_ptr;
|
||||
List<FVectorNode> candidates; // XXX List? not Queue by distance?
|
||||
List<FVectorNode> start_nodes;
|
||||
String ref_str, *ref_ptr= graph->field[1]->val_str(&ref_str);
|
||||
|
||||
ref_ptr= graph->field[1]->val_str(&ref_str);
|
||||
FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()};
|
||||
FVectorNode start_node(&ctx, ref_ptr->ptr());
|
||||
|
||||
// TODO(cvicentiu) use a random start node in last layer.
|
||||
// XXX or may be *all* nodes in the last layer? there should be few
|
||||
if (start_nodes.push_back(&start_node_ref, &ctx.root))
|
||||
if (start_nodes.push_back(&start_node, &ctx.root))
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
|
||||
FVector *v= ctx.get_fvector_from_source(start_node_ref);
|
||||
if (!v)
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
if (int err= start_node.instantiate_vector())
|
||||
return err;
|
||||
|
||||
/*
|
||||
if the query vector is NULL or invalid, VEC_DISTANCE will return
|
||||
NULL, so the result is basically unsorted, we can return rows
|
||||
in any order. For simplicity let's sort by the start_node.
|
||||
*/
|
||||
if (!res || v->size_of() != res->length())
|
||||
if (!res || ctx.vec_len * sizeof(float) != res->length())
|
||||
res= vec_field->val_str(&buf);
|
||||
|
||||
FVector target;
|
||||
if (target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length()))
|
||||
return HA_ERR_OUT_OF_MEM;
|
||||
FVector target(&ctx, res->ptr());
|
||||
|
||||
ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
|
||||
thd->variables.hnsw_ef_search, limit);
|
||||
|
|
Loading…
Reference in a new issue