mhnsw: cache neighbors too

This commit is contained in:
Sergei Golubchik 2024-06-04 23:06:44 +02:00
parent b492025c6c
commit 27bfa21a58

View file

@ -58,13 +58,17 @@ class FVectorNode: public FVector
{
private:
uchar *ref;
List<FVectorNode> *neighbors= nullptr;
char *neighbors_read= 0;
public:
FVectorNode(MHNSW_Context *ctx_, const void *ref_);
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
float distance_to(const FVector &other) const;
int instantiate_vector();
int instantiate_neighbors(size_t layer);
size_t get_ref_len() const;
uchar *get_ref() const { return ref; }
List<FVectorNode> &get_neighbors(size_t layer) const;
bool is_new() const;
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
@ -130,6 +134,55 @@ int FVectorNode::instantiate_vector()
return 0;
}
int FVectorNode::instantiate_neighbors(size_t layer)
{
if (!neighbors)
{
neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
neighbors_read= (char*)alloc_root(&ctx->root, layer+1);
bzero(neighbors_read, layer+1);
}
if (!neighbors_read[layer])
{
if (!is_new())
{
TABLE *graph= ctx->table->hlindex;
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
const size_t ref_len= get_ref_len();
graph->field[0]->store(layer, false);
graph->field[1]->store_binary(ref, ref_len);
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
if (int err= graph->file->ha_index_read_map(graph->record[0], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT))
return err;
String strbuf, *str= graph->field[2]->val_str(&strbuf);
const char *neigh_arr_bytes= str->ptr();
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
if (number_of_neighbors * ref_len + HNSW_MAX_M_WIDTH != str->length())
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
for (uint i= 0; i < number_of_neighbors; i++)
{
FVectorNode *neigh= ctx->get_node(pos);
neighbors[layer].push_back(neigh, &ctx->root);
pos+= ref_len;
}
}
neighbors_read[layer]= 1;
}
return 0;
}
/**
  Return the cached neighbor list of this node for the given layer,
  loading it through instantiate_neighbors() first.

  The status returned by instantiate_neighbors() is discarded here, so
  when loading fails the caller simply gets the list in whatever state
  it currently is.

  @param layer  layer whose neighbor list is requested
  @return mutable reference to neighbors[layer]
*/
List<FVectorNode> &FVectorNode::get_neighbors(size_t layer) const
{
  // the method is const for callers; filling the cache needs a mutable this
  FVectorNode *self= const_cast<FVectorNode*>(this);
  self->instantiate_neighbors(layer);
  return self->neighbors[layer];
}
size_t FVectorNode::get_ref_len() const
{
return ctx->table->file->ref_length;
@ -172,44 +225,8 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod
const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
const bool EXTEND_CANDIDATES=true; // XXX or false?
/**
  Read the neighbors of source_node on one layer from the graph table.

  The row keyed by (layer_number, ref of source_node) is looked up, its
  neighbor array is decoded, and every referenced node, resolved through
  ctx->get_node(), is appended to *neighbors.

  @param ctx           context providing the table, the node lookup and
                       the memory root for list elements
  @param layer_number  layer to read
  @param source_node   node whose neighbors are wanted
  @param neighbors     list receiving the neighbor nodes

  @return 0 on success, HA_ERR_CRASHED when the stored array has an
          unexpected length, or the error of ha_index_read_map()
*/
static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
                         const FVectorNode &source_node,
                         List<FVectorNode> *neighbors)
{
  TABLE *graph= ctx->table->hlindex;
  const auto key_len= graph->key_info->key_length;
  uchar *key= static_cast<uchar*>(alloca(key_len));

  // fill the key columns in record[0], then extract the lookup key
  graph->field[0]->store(layer_number, false);
  graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
  key_copy(key, graph->record[0], graph->key_info, key_len);

  int err= graph->file->ha_index_read_map(graph->record[0], key,
                                          HA_WHOLE_KEY, HA_READ_KEY_EXACT);
  if (err)
    return err;

  String strbuf;
  String *packed= graph->field[2]->val_str(&strbuf);

  // mhnsw_insert() guarantees that all ref have the same length
  uint ref_length= source_node.get_ref_len();
  const char *bytes= packed->ptr();
  uint count= HNSW_MAX_M_read(bytes);
  if (count * ref_length + HNSW_MAX_M_WIDTH != packed->length())
    return HA_ERR_CRASHED; // should not happen, corrupted HNSW index

  // the references follow the counter, ref_length bytes each
  const char *cursor= bytes + HNSW_MAX_M_WIDTH;
  for (uint left= count; left; left--)
  {
    neighbors->push_back(ctx->get_node(cursor), &ctx->root);
    cursor+= ref_length;
  }
  return 0;
}
static int select_neighbors(MHNSW_Context *ctx,
size_t layer_number, const FVector &target,
size_t layer, const FVector &target,
const List<FVectorNode> &candidates,
size_t max_neighbor_connections,
List<FVectorNode> *neighbors)
@ -242,11 +259,7 @@ static int select_neighbors(MHNSW_Context *ctx,
{
for (const FVectorNode &candidate : candidates)
{
List<FVectorNode> candidate_neighbors;
if (int err= get_neighbors(ctx, layer_number, candidate,
&candidate_neighbors))
return err;
for (const FVectorNode &extra_candidate : candidate_neighbors)
for (const FVectorNode &extra_candidate : candidate.get_neighbors(layer))
{
if (visited.find(&extra_candidate))
continue;
@ -318,7 +331,7 @@ static void dbug_print_vec_neigh(uint layer, const List<FVectorNode> &neighbors)
#endif
}
static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
static int write_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &source_node,
const List<FVectorNode> &new_neighbors)
{
@ -341,19 +354,19 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
pos+= node.get_ref_len();
}
graph->field[0]->store(layer_number, false);
graph->field[0]->store(layer, false);
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
if (source_node.is_new())
{
dbug_print_vec_ref("INSERT ", layer_number, source_node);
dbug_print_vec_ref("INSERT ", layer, source_node);
err= graph->file->ha_write_row(graph->record[0]);
}
else
{
dbug_print_vec_ref("UPDATE ", layer_number, source_node);
dbug_print_vec_neigh(layer_number, new_neighbors);
dbug_print_vec_ref("UPDATE ", layer, source_node);
dbug_print_vec_neigh(layer, new_neighbors);
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
@ -369,39 +382,33 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
}
static int update_second_degree_neighbors(MHNSW_Context *ctx,
size_t layer_number,
static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors,
const FVectorNode &source_node,
const List<FVectorNode> &neighbors)
{
//dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node);
//dbug_print_vec_neigh(layer_number, neighbors);
//dbug_print_vec_ref("Updating second degree neighbors", layer, source_node);
//dbug_print_vec_neigh(layer, neighbors);
for (const FVectorNode &neigh: neighbors) // XXX why this loop?
{
List<FVectorNode> new_neighbors;
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
return err;
new_neighbors.push_back(&source_node, &ctx->root);
if (int err= write_neighbors(ctx, layer_number, neigh, new_neighbors))
neigh.get_neighbors(layer).push_back(&source_node, &ctx->root);
if (int err= write_neighbors(ctx, layer, neigh, neigh.get_neighbors(layer)))
return err;
}
for (const FVectorNode &neigh: neighbors)
{
List<FVectorNode> new_neighbors;
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
return err;
if (new_neighbors.elements > max_neighbors)
if (neigh.get_neighbors(layer).elements > max_neighbors)
{
// shrink the neighbors
List<FVectorNode> selected;
if (int err= select_neighbors(ctx, layer_number, neigh,
new_neighbors, max_neighbors, &selected))
if (int err= select_neighbors(ctx, layer, neigh,
neigh.get_neighbors(layer),
max_neighbors, &selected))
return err;
if (int err= write_neighbors(ctx, layer_number, neigh, selected))
if (int err= write_neighbors(ctx, layer, neigh, selected))
return err;
// XXX neigh.get_neighbors(layer)= selected;
}
}
@ -410,15 +417,15 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
// Store the neighbor list of source_node for one layer with
// write_neighbors(), then hand the same list to
// update_second_degree_neighbors() together with the max_neighbors limit.
// Returns the error of write_neighbors() if it is non-zero, otherwise
// whatever update_second_degree_neighbors() returns.
static int update_neighbors(MHNSW_Context *ctx,
size_t layer_number, uint max_neighbors,
size_t layer, uint max_neighbors,
const FVectorNode &source_node,
const List<FVectorNode> &neighbors)
{
// 1. update node's neighbors
if (int err= write_neighbors(ctx, layer_number, source_node, neighbors))
if (int err= write_neighbors(ctx, layer, source_node, neighbors))
return err;
// 2. update node's neighbors' neighbors (shrink before update)
return update_second_degree_neighbors(ctx, layer_number,
return update_second_degree_neighbors(ctx, layer,
max_neighbors, source_node, neighbors);
}
@ -461,10 +468,7 @@ static int search_layer(MHNSW_Context *ctx,
// Can't get better.
}
List<FVectorNode> neighbors;
get_neighbors(ctx, layer, cur_vec, &neighbors);
for (const FVectorNode &neigh: neighbors)
for (const FVectorNode &neigh: cur_vec.get_neighbors(layer))
{
if (visited.find(&neigh))
continue;
@ -483,7 +487,6 @@ static int search_layer(MHNSW_Context *ctx,
furthest_best= best.top()->distance_to(target);
}
}
neighbors.empty();
}
DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements()));