mirror of
https://github.com/MariaDB/server.git
synced 2025-01-16 03:52:35 +01:00
mhnsw: cache neighbors too
This commit is contained in:
parent
b492025c6c
commit
27bfa21a58
1 changed file with 75 additions and 72 deletions
|
@ -58,13 +58,17 @@ class FVectorNode: public FVector
|
|||
{
|
||||
private:
|
||||
uchar *ref;
|
||||
List<FVectorNode> *neighbors= nullptr;
|
||||
char *neighbors_read= 0;
|
||||
public:
|
||||
FVectorNode(MHNSW_Context *ctx_, const void *ref_);
|
||||
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
|
||||
float distance_to(const FVector &other) const;
|
||||
int instantiate_vector();
|
||||
int instantiate_neighbors(size_t layer);
|
||||
size_t get_ref_len() const;
|
||||
uchar *get_ref() const { return ref; }
|
||||
List<FVectorNode> &get_neighbors(size_t layer) const;
|
||||
bool is_new() const;
|
||||
|
||||
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
|
||||
|
@ -130,6 +134,55 @@ int FVectorNode::instantiate_vector()
|
|||
return 0;
|
||||
}
|
||||
|
||||
int FVectorNode::instantiate_neighbors(size_t layer)
|
||||
{
|
||||
if (!neighbors)
|
||||
{
|
||||
neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
|
||||
neighbors_read= (char*)alloc_root(&ctx->root, layer+1);
|
||||
bzero(neighbors_read, layer+1);
|
||||
}
|
||||
if (!neighbors_read[layer])
|
||||
{
|
||||
if (!is_new())
|
||||
{
|
||||
TABLE *graph= ctx->table->hlindex;
|
||||
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
|
||||
const size_t ref_len= get_ref_len();
|
||||
|
||||
graph->field[0]->store(layer, false);
|
||||
graph->field[1]->store_binary(ref, ref_len);
|
||||
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
|
||||
if (int err= graph->file->ha_index_read_map(graph->record[0], key,
|
||||
HA_WHOLE_KEY, HA_READ_KEY_EXACT))
|
||||
return err;
|
||||
|
||||
String strbuf, *str= graph->field[2]->val_str(&strbuf);
|
||||
const char *neigh_arr_bytes= str->ptr();
|
||||
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
|
||||
if (number_of_neighbors * ref_len + HNSW_MAX_M_WIDTH != str->length())
|
||||
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
|
||||
|
||||
const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
|
||||
for (uint i= 0; i < number_of_neighbors; i++)
|
||||
{
|
||||
FVectorNode *neigh= ctx->get_node(pos);
|
||||
neighbors[layer].push_back(neigh, &ctx->root);
|
||||
pos+= ref_len;
|
||||
}
|
||||
}
|
||||
neighbors_read[layer]= 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
  Return the cached neighbor list of this node on the given layer,
  loading it first through instantiate_neighbors().

  The method is const for the callers' convenience; the cache is filled
  through a const_cast of this.

  NOTE(review): the status returned by instantiate_neighbors() is
  discarded, so a caller cannot tell a failed load from an empty list.

  @param layer  graph layer whose neighbor list is requested

  @return reference to neighbors[layer]
*/
List<FVectorNode> &FVectorNode::get_neighbors(size_t layer) const
{
  FVectorNode *self= const_cast<FVectorNode*>(this);
  (void) self->instantiate_neighbors(layer);
  return self->neighbors[layer];
}
|
||||
|
||||
size_t FVectorNode::get_ref_len() const
|
||||
{
|
||||
return ctx->table->file->ref_length;
|
||||
|
@ -172,44 +225,8 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod
|
|||
const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
|
||||
const bool EXTEND_CANDIDATES=true; // XXX or false?
|
||||
|
||||
/*
  Read the neighbors of source_node on the given layer from the graph
  table and append them to the caller's list.

  The row keyed by (layer_number, ref of source_node) is fetched; its
  third field holds a neighbor count followed by that many references,
  each resolved through ctx->get_node().

  @param ctx           index context (graph table, node lookup, mem root)
  @param layer_number  graph layer to read
  @param source_node   node whose neighbors are wanted
  @param neighbors     [out] list the neighbors are appended to

  @return 0 on success,
          HA_ERR_CRASHED if the stored neighbor array is malformed,
          otherwise the error returned by ha_index_read_map()
*/
static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
                         const FVectorNode &source_node,
                         List<FVectorNode> *neighbors)
{
  TABLE *graph= ctx->table->hlindex;
  uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));

  // mhnsw_insert() guarantees that all ref have the same length
  const size_t ref_length= source_node.get_ref_len();

  // build the (layer, ref) key and fetch the matching graph row
  graph->field[0]->store(layer_number, false);
  graph->field[1]->store_binary(source_node.get_ref(), ref_length);
  key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
  if (int err= graph->file->ha_index_read_map(graph->record[0], key,
                                              HA_WHOLE_KEY, HA_READ_KEY_EXACT))
    return err;

  String strbuf, *str= graph->field[2]->val_str(&strbuf);

  /*
    The value must at least hold the neighbor-count header before that
    header is decoded, otherwise HNSW_MAX_M_read() would read past the
    end of a truncated value.
    (assumes HNSW_MAX_M_read consumes HNSW_MAX_M_WIDTH bytes -- confirm)
  */
  if (str->length() < HNSW_MAX_M_WIDTH)
    return HA_ERR_CRASHED; // should not happen, corrupted HNSW index

  const char *neigh_arr_bytes= str->ptr();
  uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
  if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length())
    return HA_ERR_CRASHED; // should not happen, corrupted HNSW index

  // the references follow the header, ref_length bytes each
  const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
  for (uint i= 0; i < number_of_neighbors; i++)
  {
    FVectorNode *neigh= ctx->get_node(pos);
    neighbors->push_back(neigh, &ctx->root);
    pos+= ref_length;
  }

  return 0;
}
|
||||
|
||||
|
||||
static int select_neighbors(MHNSW_Context *ctx,
|
||||
size_t layer_number, const FVector &target,
|
||||
size_t layer, const FVector &target,
|
||||
const List<FVectorNode> &candidates,
|
||||
size_t max_neighbor_connections,
|
||||
List<FVectorNode> *neighbors)
|
||||
|
@ -242,11 +259,7 @@ static int select_neighbors(MHNSW_Context *ctx,
|
|||
{
|
||||
for (const FVectorNode &candidate : candidates)
|
||||
{
|
||||
List<FVectorNode> candidate_neighbors;
|
||||
if (int err= get_neighbors(ctx, layer_number, candidate,
|
||||
&candidate_neighbors))
|
||||
return err;
|
||||
for (const FVectorNode &extra_candidate : candidate_neighbors)
|
||||
for (const FVectorNode &extra_candidate : candidate.get_neighbors(layer))
|
||||
{
|
||||
if (visited.find(&extra_candidate))
|
||||
continue;
|
||||
|
@ -318,7 +331,7 @@ static void dbug_print_vec_neigh(uint layer, const List<FVectorNode> &neighbors)
|
|||
#endif
|
||||
}
|
||||
|
||||
static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
||||
static int write_neighbors(MHNSW_Context *ctx, size_t layer,
|
||||
const FVectorNode &source_node,
|
||||
const List<FVectorNode> &new_neighbors)
|
||||
{
|
||||
|
@ -341,19 +354,19 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
|||
pos+= node.get_ref_len();
|
||||
}
|
||||
|
||||
graph->field[0]->store(layer_number, false);
|
||||
graph->field[0]->store(layer, false);
|
||||
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
|
||||
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
|
||||
|
||||
if (source_node.is_new())
|
||||
{
|
||||
dbug_print_vec_ref("INSERT ", layer_number, source_node);
|
||||
dbug_print_vec_ref("INSERT ", layer, source_node);
|
||||
err= graph->file->ha_write_row(graph->record[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dbug_print_vec_ref("UPDATE ", layer_number, source_node);
|
||||
dbug_print_vec_neigh(layer_number, new_neighbors);
|
||||
dbug_print_vec_ref("UPDATE ", layer, source_node);
|
||||
dbug_print_vec_neigh(layer, new_neighbors);
|
||||
|
||||
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
|
||||
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
|
||||
|
@ -369,39 +382,33 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
|
|||
}
|
||||
|
||||
|
||||
static int update_second_degree_neighbors(MHNSW_Context *ctx,
|
||||
size_t layer_number,
|
||||
static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
|
||||
uint max_neighbors,
|
||||
const FVectorNode &source_node,
|
||||
const List<FVectorNode> &neighbors)
|
||||
{
|
||||
//dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node);
|
||||
//dbug_print_vec_neigh(layer_number, neighbors);
|
||||
//dbug_print_vec_ref("Updating second degree neighbors", layer, source_node);
|
||||
//dbug_print_vec_neigh(layer, neighbors);
|
||||
for (const FVectorNode &neigh: neighbors) // XXX why this loop?
|
||||
{
|
||||
List<FVectorNode> new_neighbors;
|
||||
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
|
||||
return err;
|
||||
new_neighbors.push_back(&source_node, &ctx->root);
|
||||
if (int err= write_neighbors(ctx, layer_number, neigh, new_neighbors))
|
||||
neigh.get_neighbors(layer).push_back(&source_node, &ctx->root);
|
||||
if (int err= write_neighbors(ctx, layer, neigh, neigh.get_neighbors(layer)))
|
||||
return err;
|
||||
}
|
||||
|
||||
for (const FVectorNode &neigh: neighbors)
|
||||
{
|
||||
List<FVectorNode> new_neighbors;
|
||||
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
|
||||
return err;
|
||||
|
||||
if (new_neighbors.elements > max_neighbors)
|
||||
if (neigh.get_neighbors(layer).elements > max_neighbors)
|
||||
{
|
||||
// shrink the neighbors
|
||||
List<FVectorNode> selected;
|
||||
if (int err= select_neighbors(ctx, layer_number, neigh,
|
||||
new_neighbors, max_neighbors, &selected))
|
||||
if (int err= select_neighbors(ctx, layer, neigh,
|
||||
neigh.get_neighbors(layer),
|
||||
max_neighbors, &selected))
|
||||
return err;
|
||||
if (int err= write_neighbors(ctx, layer_number, neigh, selected))
|
||||
if (int err= write_neighbors(ctx, layer, neigh, selected))
|
||||
return err;
|
||||
// XXX neigh.get_neighbors(layer)= selected;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -410,15 +417,15 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
|
|||
|
||||
|
||||
static int update_neighbors(MHNSW_Context *ctx,
|
||||
size_t layer_number, uint max_neighbors,
|
||||
size_t layer, uint max_neighbors,
|
||||
const FVectorNode &source_node,
|
||||
const List<FVectorNode> &neighbors)
|
||||
{
|
||||
// 1. update node's neighbors
|
||||
if (int err= write_neighbors(ctx, layer_number, source_node, neighbors))
|
||||
if (int err= write_neighbors(ctx, layer, source_node, neighbors))
|
||||
return err;
|
||||
// 2. update node's neighbors' neighbors (shrink before update)
|
||||
return update_second_degree_neighbors(ctx, layer_number,
|
||||
return update_second_degree_neighbors(ctx, layer,
|
||||
max_neighbors, source_node, neighbors);
|
||||
}
|
||||
|
||||
|
@ -461,10 +468,7 @@ static int search_layer(MHNSW_Context *ctx,
|
|||
// Can't get better.
|
||||
}
|
||||
|
||||
List<FVectorNode> neighbors;
|
||||
get_neighbors(ctx, layer, cur_vec, &neighbors);
|
||||
|
||||
for (const FVectorNode &neigh: neighbors)
|
||||
for (const FVectorNode &neigh: cur_vec.get_neighbors(layer))
|
||||
{
|
||||
if (visited.find(&neigh))
|
||||
continue;
|
||||
|
@ -483,7 +487,6 @@ static int search_layer(MHNSW_Context *ctx,
|
|||
furthest_best= best.top()->distance_to(target);
|
||||
}
|
||||
}
|
||||
neighbors.empty();
|
||||
}
|
||||
DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements()));
|
||||
|
||||
|
|
Loading…
Reference in a new issue