diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index 4aa6499ba34..6451b4c355f 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -58,13 +58,17 @@ class FVectorNode: public FVector { private: uchar *ref; + List<FVectorNode> *neighbors= nullptr; + char *neighbors_read= 0; public: FVectorNode(MHNSW_Context *ctx_, const void *ref_); FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_); float distance_to(const FVector &other) const; int instantiate_vector(); + int instantiate_neighbors(size_t layer); size_t get_ref_len() const; uchar *get_ref() const { return ref; } + List<FVectorNode> &get_neighbors(size_t layer) const; bool is_new() const; static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool); @@ -130,6 +134,55 @@ int FVectorNode::instantiate_vector() return 0; } +int FVectorNode::instantiate_neighbors(size_t layer) +{ + if (!neighbors) + { + neighbors= new (&ctx->root) List<FVectorNode>[layer+1]; + neighbors_read= (char*)alloc_root(&ctx->root, layer+1); + bzero(neighbors_read, layer+1); + } + if (!neighbors_read[layer]) + { + if (!is_new()) + { + TABLE *graph= ctx->table->hlindex; + uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); + const size_t ref_len= get_ref_len(); + + graph->field[0]->store(layer, false); + graph->field[1]->store_binary(ref, ref_len); + key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); + if (int err= graph->file->ha_index_read_map(graph->record[0], key, + HA_WHOLE_KEY, HA_READ_KEY_EXACT)) + return err; + + String strbuf, *str= graph->field[2]->val_str(&strbuf); + const char *neigh_arr_bytes= str->ptr(); + uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes); + if (number_of_neighbors * ref_len + HNSW_MAX_M_WIDTH != str->length()) + return HA_ERR_CRASHED; // should not happen, corrupted HNSW index + + const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH; + for (uint i= 0; i < number_of_neighbors; i++) + { + FVectorNode *neigh= ctx->get_node(pos); + neighbors[layer].push_back(neigh, &ctx->root); + pos+= ref_len; + } + } + neighbors_read[layer]= 1; + } + + return 0; +} + +List<FVectorNode> &FVectorNode::get_neighbors(size_t layer) const +{ + const_cast<FVectorNode*>(this)->instantiate_neighbors(layer); + return neighbors[layer]; +} + size_t FVectorNode::get_ref_len() const { return ctx->table->file->ref_length; @@ -172,44 +225,8 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why? const bool EXTEND_CANDIDATES=true; // XXX or false? -static int get_neighbors(MHNSW_Context *ctx, size_t layer_number, - const FVectorNode &source_node, - List<FVectorNode> *neighbors) -{ - TABLE *graph= ctx->table->hlindex; - uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); - - graph->field[0]->store(layer_number, false); - graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len()); - key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); - if (int err= graph->file->ha_index_read_map(graph->record[0], key, - HA_WHOLE_KEY, HA_READ_KEY_EXACT)) - return err; - - String strbuf, *str= graph->field[2]->val_str(&strbuf); - - // mhnsw_insert() guarantees that all ref have the same length - uint ref_length= source_node.get_ref_len(); - - const char *neigh_arr_bytes= str->ptr(); - uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes); - if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length()) - return HA_ERR_CRASHED; // should not happen, corrupted HNSW index - - const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH; - for (uint i= 0; i < number_of_neighbors; i++) - { - FVectorNode *neigh= ctx->get_node(pos); - neighbors->push_back(neigh, &ctx->root); - pos+= ref_length; - } - - return 0; -} - - static int select_neighbors(MHNSW_Context *ctx, - size_t layer_number, const FVector &target, + size_t layer, const FVector &target, const List<FVectorNode> &candidates, size_t max_neighbor_connections, List<FVectorNode> *neighbors) @@ -242,11 +259,7 @@ static int select_neighbors(MHNSW_Context *ctx, { for (const FVectorNode &candidate : candidates) { - List<FVectorNode> candidate_neighbors; - if (int err= get_neighbors(ctx, layer_number, candidate, - &candidate_neighbors)) - return err; - for (const FVectorNode &extra_candidate : candidate_neighbors) + for (const FVectorNode &extra_candidate : candidate.get_neighbors(layer)) { if (visited.find(&extra_candidate)) continue; @@ -318,7 +331,7 @@ static void dbug_print_vec_neigh(uint layer, const List<FVectorNode> &neighbors) #endif } -static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, +static int write_neighbors(MHNSW_Context *ctx, size_t layer, const FVectorNode &source_node, const List<FVectorNode> &new_neighbors) { @@ -341,19 +354,19 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, pos+= node.get_ref_len(); } - graph->field[0]->store(layer_number, false); + graph->field[0]->store(layer, false); graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len()); graph->field[2]->store_binary(neighbor_array_bytes, total_size); if (source_node.is_new()) { - dbug_print_vec_ref("INSERT ", layer_number, source_node); + dbug_print_vec_ref("INSERT ", layer, source_node); err= graph->file->ha_write_row(graph->record[0]); } else { - dbug_print_vec_ref("UPDATE ", layer_number, source_node); - dbug_print_vec_neigh(layer_number, new_neighbors); + dbug_print_vec_ref("UPDATE ", layer, source_node); + dbug_print_vec_neigh(layer, new_neighbors); uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); @@ -369,39 +382,33 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, } -static int update_second_degree_neighbors(MHNSW_Context *ctx, - size_t layer_number, +static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer, uint max_neighbors, const FVectorNode &source_node, const List<FVectorNode> &neighbors) { - //dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node); - //dbug_print_vec_neigh(layer_number, neighbors); + //dbug_print_vec_ref("Updating second degree neighbors", layer, source_node); + //dbug_print_vec_neigh(layer, neighbors); for (const FVectorNode &neigh: neighbors) // XXX why this loop? { - List<FVectorNode> new_neighbors; - if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors)) - return err; - new_neighbors.push_back(&source_node, &ctx->root); - if (int err= write_neighbors(ctx, layer_number, neigh, new_neighbors)) + neigh.get_neighbors(layer).push_back(&source_node, &ctx->root); + if (int err= write_neighbors(ctx, layer, neigh, neigh.get_neighbors(layer))) return err; } for (const FVectorNode &neigh: neighbors) { - List<FVectorNode> new_neighbors; - if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors)) - return err; - - if (new_neighbors.elements > max_neighbors) + if (neigh.get_neighbors(layer).elements > max_neighbors) { // shrink the neighbors List<FVectorNode> selected; - if (int err= select_neighbors(ctx, layer_number, neigh, - new_neighbors, max_neighbors, &selected)) + if (int err= select_neighbors(ctx, layer, neigh, + neigh.get_neighbors(layer), + max_neighbors, &selected)) return err; - if (int err= write_neighbors(ctx, layer_number, neigh, selected)) + if (int err= write_neighbors(ctx, layer, neigh, selected)) return err; + // XXX neigh.get_neighbors(layer)= selected; } } @@ -410,15 +417,15 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, static int update_neighbors(MHNSW_Context *ctx, - size_t layer_number, uint max_neighbors, + size_t layer, uint max_neighbors, const FVectorNode &source_node, const List<FVectorNode> &neighbors) { // 1. update node's neighbors - if (int err= write_neighbors(ctx, layer_number, source_node, neighbors)) + if (int err= write_neighbors(ctx, layer, source_node, neighbors)) return err; // 2. update node's neighbors' neighbors (shrink before update) - return update_second_degree_neighbors(ctx, layer_number, + return update_second_degree_neighbors(ctx, layer, max_neighbors, source_node, neighbors); } @@ -461,10 +468,7 @@ static int search_layer(MHNSW_Context *ctx, // Can't get better. } - List<FVectorNode> neighbors; - get_neighbors(ctx, layer, cur_vec, &neighbors); - - for (const FVectorNode &neigh: neighbors) + for (const FVectorNode &neigh: cur_vec.get_neighbors(layer)) { if (visited.find(&neigh)) continue; @@ -483,7 +487,6 @@ static int search_layer(MHNSW_Context *ctx, furthest_best= best.top()->distance_to(target); } } - neighbors.empty(); } DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements()));