/*
   Copyright (c) 2024, MariaDB plc

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; version 2 of the License.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA
*/

#include <my_global.h>
#include "key.h"                                // key_copy()
#include "create_options.h"
#include "vector_mhnsw.h"
#include "item_vectorfunc.h"
#include <scope.h>
#include <my_atomic_wrapper.h>
#include "bloom_filters.h"

// Algorithm parameters
static constexpr float alpha= 1.1f;
static constexpr float generosity= 1.1f;
static constexpr uint ef_construction= 10;

static ulonglong mhnsw_cache_size;
static MYSQL_SYSVAR_ULONGLONG(cache_size, mhnsw_cache_size,
       PLUGIN_VAR_RQCMDARG, "Size of the cache for the MHNSW vector index",
       nullptr, nullptr, 16*1024*1024, 1024*1024, SIZE_T_MAX, 1);
static MYSQL_THDVAR_UINT(min_limit, PLUGIN_VAR_RQCMDARG,
       "Defines the minimal number of result candidates to look for in the "
       "vector index for ORDER BY ... LIMIT N queries. The search will never "
       "search for fewer rows than that, even if LIMIT is smaller. "
       "This notably improves the search quality at low LIMIT values, "
       "at the expense of search time", nullptr, nullptr, 20, 1, 65535, 1);
static MYSQL_THDVAR_UINT(max_edges_per_node, PLUGIN_VAR_RQCMDARG,
       "Larger values mean slower INSERT, larger index size and higher "
       "memory consumption, but better search results",
       nullptr, nullptr, 6, 3, 200, 1);

enum metric_type : uint { EUCLIDEAN, COSINE };
static const char *distance_function_names[]= { "euclidean", "cosine", nullptr };
static TYPELIB distance_functions= CREATE_TYPELIB_FOR(distance_function_names);
static MYSQL_THDVAR_ENUM(distance_function, PLUGIN_VAR_RQCMDARG,
       "Distance function to build the vector index for", nullptr, nullptr,
       EUCLIDEAN, &distance_functions);

struct ha_index_option_struct
{
  ulonglong M;                          // option struct does not support uint
  metric_type metric;
};

enum Graph_table_fields {
  FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS
};
enum Graph_table_indices {
  IDX_TREF, IDX_LAYER
};

class MHNSW_Context;
class FVectorNode;

/*
  One vector, an array of coordinates in ctx->vec_len dimensions
*/
#pragma pack(push, 1)
struct FVector
{
  static constexpr size_t data_header= sizeof(float);
  static constexpr size_t alloc_header= data_header + sizeof(float);

  float abs2, scale;
  int16_t dims[4];

  uchar *data() const { return (uchar*)(&scale); }

  static size_t data_size(size_t n) { return data_header + n*2; }
  static size_t data_to_value_size(size_t data_size)
  { return (data_size - data_header)*2; }
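  /*
    The vector is quantized to int16 with a per-vector scale factor:
    scale is the component with the largest magnitude (sign preserved)
    divided by 32767, and dims[i]= round(value[i]/scale), so the largest
    component maps to +/-32767 and the rest keep ~15 bits of precision.
    For example, {0.5, 1.0, 0.25} gives scale= 1.0f/32767 and
    dims= {16384, 32767, 8192}.
  */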
  static const FVector *create(metric_type metric, void *mem, const void *src,
                               size_t src_len)
  {
    float scale=0, *v= (float *)src;
    size_t vec_len= src_len / sizeof(float);
    for (size_t i= 0; i < vec_len; i++)
      if (std::abs(scale) < std::abs(get_float(v + i)))
        scale= get_float(v + i);

    FVector *vec= align_ptr(mem);
    vec->scale= scale ? scale/32767 : 1;
    for (size_t i= 0; i < vec_len; i++)
      vec->dims[i]= static_cast<int16_t>(std::round(get_float(v + i) / vec->scale));
    vec->postprocess(vec_len);
    if (metric == COSINE)
    {
      if (vec->abs2 > 0.0f)
        vec->scale/= std::sqrt(vec->abs2);
      vec->abs2= 1.0f;
    }
    return vec;
  }

  void postprocess(size_t vec_len)
  {
    fix_tail(vec_len);
    abs2= scale * scale * dot_product(dims, dims, vec_len) / 2;
  }

#ifdef AVX2_IMPLEMENTATION
  /************* AVX2 *****************************************************/
  static constexpr size_t AVX2_bytes= 256/8;
  static constexpr size_t AVX2_dims= AVX2_bytes/sizeof(int16_t);

  AVX2_IMPLEMENTATION
  static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
  {
    typedef float v8f __attribute__((vector_size(AVX2_bytes)));
    union { v8f v; __m256 i; } tmp;
    __m256i *p1= (__m256i*)v1;
    __m256i *p2= (__m256i*)v2;
    v8f d= {0};
    for (size_t i= 0; i < (len + AVX2_dims-1)/AVX2_dims; p1++, p2++, i++)
    {
      tmp.i= _mm256_cvtepi32_ps(_mm256_madd_epi16(*p1, *p2));
      d+= tmp.v;
    }
    return d[0] + d[1] + d[2] + d[3] + d[4] + d[5] + d[6] + d[7];
  }

  AVX2_IMPLEMENTATION
  static size_t alloc_size(size_t n)
  { return alloc_header + MY_ALIGN(n*2, AVX2_bytes) + AVX2_bytes - 1; }

  AVX2_IMPLEMENTATION
  static FVector *align_ptr(void *ptr)
  { return (FVector*)(MY_ALIGN(((intptr)ptr) + alloc_header, AVX2_bytes)
                      - alloc_header); }

  AVX2_IMPLEMENTATION
  void fix_tail(size_t vec_len)
  {
    bzero(dims + vec_len, (MY_ALIGN(vec_len, AVX2_dims) - vec_len)*2);
  }
#endif

#ifdef AVX512_IMPLEMENTATION
  /************* AVX512 ****************************************************/
  static constexpr size_t AVX512_bytes= 512/8;
  static constexpr size_t AVX512_dims= AVX512_bytes/sizeof(int16_t);

  AVX512_IMPLEMENTATION
  static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
  {
    __m512i *p1= (__m512i*)v1;
    __m512i *p2= (__m512i*)v2;
    __m512 d= _mm512_setzero_ps();
    for (size_t i= 0; i < (len + AVX512_dims-1)/AVX512_dims; p1++, p2++, i++)
      d= _mm512_add_ps(d, _mm512_cvtepi32_ps(_mm512_madd_epi16(*p1, *p2)));
    return _mm512_reduce_add_ps(d);
  }

  AVX512_IMPLEMENTATION
  static size_t alloc_size(size_t n)
  { return alloc_header + MY_ALIGN(n*2, AVX512_bytes) + AVX512_bytes - 1; }

  AVX512_IMPLEMENTATION
  static FVector *align_ptr(void *ptr)
  { return (FVector*)(MY_ALIGN(((intptr)ptr) + alloc_header, AVX512_bytes)
                      - alloc_header); }

  AVX512_IMPLEMENTATION
  void fix_tail(size_t vec_len)
  {
    bzero(dims + vec_len, (MY_ALIGN(vec_len, AVX512_dims) - vec_len)*2);
  }
#endif

  /************* no-SIMD default ******************************************/
  DEFAULT_IMPLEMENTATION
  static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
  {
    int64_t d= 0;
    for (size_t i= 0; i < len; i++)
      d+= int32_t(v1[i]) * int32_t(v2[i]);
    return static_cast<float>(d);
  }

  DEFAULT_IMPLEMENTATION
  static size_t alloc_size(size_t n) { return alloc_header + n*2; }

  DEFAULT_IMPLEMENTATION
  static FVector *align_ptr(void *ptr) { return (FVector*)ptr; }

  DEFAULT_IMPLEMENTATION
  void fix_tail(size_t) { }
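  /*
    postprocess() caches abs2 = scale^2 * (dims . dims) / 2, that is |v|^2/2.
    distance_to() below thus computes |a|^2/2 + |b|^2/2 - a.b, which is half
    of the squared Euclidean distance. For COSINE, create() rescales the
    vector so the same expression becomes 2*(1 - cosine similarity). Both
    are monotone in the true distance, which is all the search needs for
    ordering.
  */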
  float distance_to(const FVector *other, size_t vec_len) const
  {
    return abs2 + other->abs2 - scale * other->scale *
           dot_product(dims, other->dims, vec_len);
  }
};
#pragma pack(pop)

/*
  An array of pointers to graph nodes

  It's mainly used to store all neighbors of a given node on a given layer.
  The array is of fixed size, 2*M for the zero layer, M for the other layers,
  see MHNSW_Context::max_neighbors().

  The number of neighbors is zero-padded to a multiple of 8 (for the SIMD
  Bloom filter).

  Also used as a simple array of nodes in search_layer, the array size
  then is defined by ef or efConstruction.
*/
struct Neighborhood: public Sql_alloc
{
  FVectorNode **links;
  size_t num;
  FVectorNode **init(FVectorNode **ptr, size_t n)
  {
    num= 0;
    links= ptr;
    n= MY_ALIGN(n, 8);
    bzero(ptr, n*sizeof(*ptr));
    return ptr + n;
  }
};

/*
  One node in a graph = one row in the graph table

  stores a vector itself, ref (= position) in the graph (= hlindex) table,
  a ref in the main table, and an array of Neighborhood's, one per layer.

  It's lazily initialized, may know only gref, everything else is
  loaded on demand.

  On the other hand, on INSERT the new node knows everything except gref -
  which only becomes known after ha_write_row.

  Allocated on the memroot in two chunks. One is the same size for all nodes
  and stores the FVectorNode object, gref, tref, and the vector. The second
  stores neighbors, all Neighborhood's together, its size depends on the
  number of layers this node is on.

  There can be millions of nodes in the cache and the cache size is
  constrained by mhnsw_cache_size, so every byte matters here
*/
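/*
  The fixed-size chunk is laid out as (see gref(), tref() and make_vec()):

    [FVectorNode][gref: gref_len bytes][tref: tref_len bytes][FVector ...]

  with the FVector placed at whatever alignment the SIMD implementation in
  use requires (FVector::align_ptr).
*/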
#pragma pack(push, 1)
class FVectorNode
{
private:
  MHNSW_Context *ctx;

  const FVector *make_vec(const void *v);
  int alloc_neighborhood(uint8_t layer);

public:
  const FVector *vec= nullptr;
  Neighborhood *neighbors= nullptr;
  uint8_t max_layer;
  bool stored:1, deleted:1;

  FVectorNode(MHNSW_Context *ctx_, const void *gref_);
  FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
              const void *vec_);
  float distance_to(const FVector *other) const;
  int load(TABLE *graph);
  int load_from_record(TABLE *graph);
  int save(TABLE *graph);
  size_t tref_len() const;
  size_t gref_len() const;
  uchar *gref() const;
  uchar *tref() const;
  void push_neighbor(size_t layer, FVectorNode *v);

  static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
};
#pragma pack(pop)

/*
  Shared algorithm context. The graph.

  Stored in TABLE_SHARE and on TABLE_SHARE::mem_root.

  Stores the complete graph in MHNSW_Context::root, the mapping
  gref->FVectorNode is in the node_cache.

  Both root and node_cache are protected by a cache_lock, but it's only
  needed when loading nodes and is not used when the whole graph is in
  memory.

  The graph can be traversed concurrently by different threads, as traversal
  changes neither nodes nor the ctx.

  Nodes can be loaded concurrently by different threads, this is protected
  by a partitioned node_lock.

  A reference counter allows flushing the graph without interrupting
  concurrent searches.

  MyISAM automatically gets exclusive write access because of the TL_WRITE,
  but InnoDB has to use a dedicated ctx->commit_lock for that
*/
class MHNSW_Context : public Sql_alloc
{
  std::atomic<uint> refcnt{0};
  mysql_mutex_t cache_lock;
  mysql_mutex_t node_lock[8];

  void cache_internal(FVectorNode *node)
  {
    DBUG_ASSERT(node->stored);
    node_cache.insert(node);
  }

  void *alloc_node_internal()
  {
    return alloc_root(&root, sizeof(FVectorNode) + gref_len + tref_len +
                             FVector::alloc_size(vec_len));
  }

protected:
  MEM_ROOT root;
  Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};

public:
  mysql_rwlock_t commit_lock;
  size_t vec_len= 0;
  size_t byte_len= 0;
  Atomic_relaxed<double> ef_power{0.6}; // for the bloom filter size heuristic
  FVectorNode *start= 0;
  const uint tref_len;
  const uint gref_len;
  const uint M;
  metric_type metric;

  MHNSW_Context(TABLE *t)
    : tref_len(t->file->ref_length),
      gref_len(t->hlindex->file->ref_length),
      M(static_cast<uint>(t->s->key_info[t->s->keys].option_struct->M)),
      metric(t->s->key_info[t->s->keys].option_struct->metric)
  {
    mysql_rwlock_init(PSI_INSTRUMENT_ME, &commit_lock);
    mysql_mutex_init(PSI_INSTRUMENT_ME, &cache_lock, MY_MUTEX_INIT_FAST);
    for (uint i=0; i < array_elements(node_lock); i++)
      mysql_mutex_init(PSI_INSTRUMENT_ME, node_lock + i, MY_MUTEX_INIT_SLOW);
    init_alloc_root(PSI_INSTRUMENT_MEM, &root, 1024*1024, 0, MYF(0));
  }

  virtual ~MHNSW_Context()
  {
    free_root(&root, MYF(0));
    mysql_rwlock_destroy(&commit_lock);
    mysql_mutex_destroy(&cache_lock);
    for (size_t i=0; i < array_elements(node_lock); i++)
      mysql_mutex_destroy(node_lock + i);
  }
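  /*
    Node loads are serialized per node: the node pointer is hashed to pick
    one of the node_lock mutexes, so different nodes can be loaded in
    parallel, while two threads loading the same node take the same mutex.
  */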
  uint lock_node(FVectorNode *ptr)
  {
    ulong nr1= 1, nr2= 4;
    my_hash_sort_bin(0, (uchar*)&ptr, sizeof(ptr), &nr1, &nr2);
    uint ticket= nr1 % array_elements(node_lock);
    mysql_mutex_lock(node_lock + ticket);
    return ticket;
  }

  void unlock_node(uint ticket)
  {
    mysql_mutex_unlock(node_lock + ticket);
  }

  uint max_neighbors(size_t layer) const
  {
    return (layer ? 1 : 2) * M;         // heuristic from the paper
  }

  void set_lengths(size_t len)
  {
    byte_len= len;
    vec_len= len / sizeof(float);
  }

  static int acquire(MHNSW_Context **ctx, TABLE *table, bool for_update);
  static MHNSW_Context *get_from_share(TABLE_SHARE *share, TABLE *table);

  virtual void reset(TABLE_SHARE *share)
  {
    share->lock_share();
    if (static_cast<MHNSW_Context*>(share->hlindex->hlindex_data) == this)
    {
      share->hlindex->hlindex_data= nullptr;
      --refcnt;
    }
    share->unlock_share();
  }

  void release(TABLE *table)
  {
    return release(table->file->has_transactions(), table->s);
  }

  virtual void release(bool can_commit, TABLE_SHARE *share)
  {
    if (can_commit)
      mysql_rwlock_unlock(&commit_lock);
    if (root_size(&root) > mhnsw_cache_size)
      reset(share);
    if (--refcnt == 0)
      this->~MHNSW_Context();           // XXX reuse
  }

  FVectorNode *get_node(const void *gref)
  {
    mysql_mutex_lock(&cache_lock);
    FVectorNode *node= node_cache.find(gref, gref_len);
    if (!node)
    {
      node= new (alloc_node_internal()) FVectorNode(this, gref);
      cache_internal(node);
    }
    mysql_mutex_unlock(&cache_lock);
    return node;
  }

  /* used on INSERT, gref isn't known, so cannot cache the node yet */
  void *alloc_node()
  {
    mysql_mutex_lock(&cache_lock);
    auto p= alloc_node_internal();
    mysql_mutex_unlock(&cache_lock);
    return p;
  }

  /* explicitly cache the node after alloc_node() */
  void cache_node(FVectorNode *node)
  {
    mysql_mutex_lock(&cache_lock);
    cache_internal(node);
    mysql_mutex_unlock(&cache_lock);
  }

  /* find the node without creating, only used on merging trx->ctx */
  FVectorNode *find_node(const void *gref)
  {
    mysql_mutex_lock(&cache_lock);
    FVectorNode *node= node_cache.find(gref, gref_len);
    mysql_mutex_unlock(&cache_lock);
    return node;
  }
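  /*
    One allocation covers all layers: max_layer+1 Neighborhood headers
    followed by the link arrays - MY_ALIGN(M,4)*2 pointers for layer 0
    (which allows 2*M links) and MY_ALIGN(M,8) pointers for every layer
    above it, matching the 8-pointer padding done in Neighborhood::init().
  */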
  void *alloc_neighborhood(size_t max_layer)
  {
    mysql_mutex_lock(&cache_lock);
    auto p= alloc_root(&root, sizeof(Neighborhood)*(max_layer+1) +
              sizeof(FVectorNode*)*(MY_ALIGN(M, 4)*2 + MY_ALIGN(M,8)*max_layer));
    mysql_mutex_unlock(&cache_lock);
    return p;
  }
};

/*
  This is a non-shared context that exists within one transaction.

  At the end of the transaction it's either discarded (on rollback)
  or merged into the shared ctx (on commit).

  trx's are stored in thd->ha_data[] in a single-linked list,
  one instance of trx per TABLE_SHARE and allocated on the
  thd->transaction->mem_root
*/
class MHNSW_Trx : public MHNSW_Context
{
public:
  TABLE_SHARE *table_share;
  bool list_of_nodes_is_lost= false;
  MHNSW_Trx *next= nullptr;

  MHNSW_Trx(TABLE *table) : MHNSW_Context(table), table_share(table->s) {}

  void reset(TABLE_SHARE *) override
  {
    node_cache.clear();
    free_root(&root, MYF(0));
    start= 0;
    list_of_nodes_is_lost= true;
  }

  void release(bool, TABLE_SHARE *) override
  {
    if (root_size(&root) > mhnsw_cache_size)
      reset(nullptr);
  }

  static MHNSW_Trx *get_from_thd(TABLE *table, bool for_update);

  // it's okay in a transaction-local cache, there's no concurrent access
  Hash_set<FVectorNode> &get_cache() { return node_cache; }

  static transaction_participant tp;
  static int do_commit(THD *thd, bool);
  static int do_savepoint_rollback(THD *thd, void *);
  static int do_rollback(THD *thd, bool);
};

struct transaction_participant MHNSW_Trx::tp=
{
  0, 0, 0,
  nullptr,                              /* close_connection */
  [](THD *, void *){ return 0; },       /* savepoint_set */
  MHNSW_Trx::do_savepoint_rollback,
  [](THD *thd){ return true; },         /* savepoint_rollback_can_release_mdl */
  nullptr,                              /* savepoint_release */
  MHNSW_Trx::do_commit,
  MHNSW_Trx::do_rollback,
  nullptr,                              /* prepare */
  nullptr,                              /* recover */
  nullptr, nullptr,                     /* commit/rollback_by_xid */
  nullptr, nullptr,                     /* recover_rollback_by_xid/recovery_done */
  nullptr, nullptr, nullptr,            /* snapshot, commit/prepare_ordered */
  nullptr, nullptr                      /* checkpoint, versioned */
};

int MHNSW_Trx::do_savepoint_rollback(THD *thd, void *)
{
  for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp)); trx;
       trx= trx->next)
    trx->reset(nullptr);
  return 0;
}

int MHNSW_Trx::do_rollback(THD *thd, bool)
{
  MHNSW_Trx *trx_next;
  for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp)); trx;
       trx= trx_next)
  {
    trx_next= trx->next;
    trx->~MHNSW_Trx();
  }
  thd_set_ha_data(current_thd, &tp, nullptr);
  return 0;
}

int MHNSW_Trx::do_commit(THD *thd, bool)
{
  MHNSW_Trx *trx_next;
  for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp)); trx;
       trx= trx_next)
  {
    trx_next= trx->next;
    auto ctx= MHNSW_Context::get_from_share(trx->table_share, nullptr);
    if (ctx)
    {
      mysql_rwlock_wrlock(&ctx->commit_lock);
      if (trx->list_of_nodes_is_lost)
        ctx->reset(trx->table_share);
      else
      {
        // consider copying nodes from trx to shared cache when it makes sense
        // for ann_benchmarks it does not
        // also, consider flushing only changed nodes (a flag in the node)
        for (FVectorNode &from : trx->get_cache())
          if (FVectorNode *node= ctx->find_node(from.gref()))
            node->vec= nullptr;
        ctx->start= nullptr;
      }
      ctx->release(true, trx->table_share);
    }
    trx->~MHNSW_Trx();
  }
  thd_set_ha_data(current_thd, &tp, nullptr);
  return 0;
}

MHNSW_Trx *MHNSW_Trx::get_from_thd(TABLE *table, bool for_update)
{
  if (!table->file->has_transactions())
    return NULL;

  THD *thd= table->in_use;
  auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp));
  if (!for_update && !trx)
    return NULL;

  while (trx && trx->table_share != table->s) trx= trx->next;

  if (!trx)
  {
    trx= new (&thd->transaction->mem_root) MHNSW_Trx(table);
    trx->next= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp));
    thd_set_ha_data(thd, &tp, trx);
    if (!trx->next)
    {
      bool all= thd_test_options(thd, OPTION_NOT_AUTOCOMMIT | OPTION_BEGIN);
      trans_register_ha(thd, all, &tp, 0);
    }
  }
  return trx;
}
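/*
  Return the shared per-TABLE_SHARE context, creating it on first use when
  a TABLE is supplied. Every successful call takes one reference that the
  caller must drop with release().
*/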
MHNSW_Context *MHNSW_Context::get_from_share(TABLE_SHARE *share, TABLE *table)
{
  share->lock_share();
  auto ctx= static_cast<MHNSW_Context*>(share->hlindex->hlindex_data);
  if (!ctx && table)
  {
    ctx= new (&share->hlindex->mem_root) MHNSW_Context(table);
    if (!ctx)
    {
      share->unlock_share();
      return nullptr;
    }
    share->hlindex->hlindex_data= ctx;
    ctx->refcnt++;
  }
  if (ctx)
    ctx->refcnt++;
  share->unlock_share();
  return ctx;
}

int MHNSW_Context::acquire(MHNSW_Context **ctx, TABLE *table, bool for_update)
{
  TABLE *graph= table->hlindex;

  if (!(*ctx= MHNSW_Trx::get_from_thd(table, for_update)))
  {
    *ctx= MHNSW_Context::get_from_share(table->s, table);
    if (table->file->has_transactions())
      mysql_rwlock_rdlock(&(*ctx)->commit_lock);
  }

  if ((*ctx)->start)
    return 0;

  if (int err= graph->file->ha_index_init(IDX_LAYER, 1))
    return err;

  int err= graph->file->ha_index_last(graph->record[0]);
  graph->file->ha_index_end();
  if (err)
    return err;

  graph->file->position(graph->record[0]);
  (*ctx)->set_lengths(FVector::data_to_value_size(
                        graph->field[FIELD_VEC]->value_length()));
  (*ctx)->start= (*ctx)->get_node(graph->file->ref);
  return (*ctx)->start->load_from_record(graph);
}

/* copy the vector, preprocessed as needed */
const FVector *FVectorNode::make_vec(const void *v)
{
  return FVector::create(ctx->metric, tref() + tref_len(), v, ctx->byte_len);
}

FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_)
  : ctx(ctx_), stored(true), deleted(false)
{
  memcpy(gref(), gref_, gref_len());
}

FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
                         const void *vec_)
  : ctx(ctx_), stored(false), deleted(false)
{
  DBUG_ASSERT(tref_);
  memset(gref(), 0xff, gref_len());     // important: larger than any real gref
  memcpy(tref(), tref_, tref_len());
  vec= make_vec(vec_);

  alloc_neighborhood(layer);
}

float FVectorNode::distance_to(const FVector *other) const
{
  return vec->distance_to(other, ctx->vec_len);
}

int FVectorNode::alloc_neighborhood(uint8_t layer)
{
  if (neighbors)
    return 0;
  max_layer= layer;
  neighbors= (Neighborhood*)ctx->alloc_neighborhood(layer);
  auto ptr= (FVectorNode**)(neighbors + (layer+1));
  for (size_t i= 0; i <= layer; i++)
    ptr= neighbors[i].init(ptr, ctx->max_neighbors(i));
  return 0;
}

int FVectorNode::load(TABLE *graph)
{
  if (likely(vec))
    return 0;

  DBUG_ASSERT(stored);
  // trx: consider loading nodes from shared, when it makes sense
  // for ann_benchmarks it does not
  if (int err= graph->file->ha_rnd_pos(graph->record[0], gref()))
    return err;
  return load_from_record(graph);
}
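/*
  Parse the graph-table row into this node. The per-node lock serializes
  concurrent loads of the same node; vec is re-checked under the lock and
  assigned only at the very end, so readers that test vec without the lock
  never see a half-initialized node.
*/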
int FVectorNode::load_from_record(TABLE *graph)
{
  DBUG_ASSERT(ctx->byte_len);

  uint ticket= ctx->lock_node(this);
  SCOPE_EXIT([this, ticket](){ ctx->unlock_node(ticket); });

  if (vec)
    return 0;

  String buf, *v= graph->field[FIELD_TREF]->val_str(&buf);
  deleted= graph->field[FIELD_TREF]->is_null();
  if (!deleted)
  {
    if (unlikely(v->length() != tref_len()))
      return my_errno= HA_ERR_CRASHED;
    memcpy(tref(), v->ptr(), v->length());
  }

  v= graph->field[FIELD_VEC]->val_str(&buf);
  if (unlikely(!v))
    return my_errno= HA_ERR_CRASHED;

  if (v->length() != FVector::data_size(ctx->vec_len))
    return my_errno= HA_ERR_CRASHED;

  FVector *vec_ptr= FVector::align_ptr(tref() + tref_len());
  memcpy(vec_ptr->data(), v->ptr(), v->length());
  vec_ptr->postprocess(ctx->vec_len);

  longlong layer= graph->field[FIELD_LAYER]->val_int();
  if (layer > 100)                      // 10e30 nodes at M=2, more at larger M's
    return my_errno= HA_ERR_CRASHED;

  if (int err= alloc_neighborhood(static_cast<uint8_t>(layer)))
    return err;

  v= graph->field[FIELD_NEIGHBORS]->val_str(&buf);
  if (unlikely(!v))
    return my_errno= HA_ERR_CRASHED;

  // per layer: <num_neighbors> <gref> <gref> ... etc...
  uchar *ptr= (uchar*)v->ptr(), *end= ptr + v->length();
  for (size_t i=0; i <= max_layer; i++)
  {
    if (unlikely(ptr >= end))
      return my_errno= HA_ERR_CRASHED;
    size_t grefs= *ptr++;
    if (unlikely(ptr + grefs * gref_len() > end))
      return my_errno= HA_ERR_CRASHED;
    neighbors[i].num= grefs;
    for (size_t j=0; j < grefs; j++, ptr+= gref_len())
      neighbors[i].links[j]= ctx->get_node(ptr);
  }
  vec= vec_ptr;                         // must be done at the very end
  return 0;
}

void FVectorNode::push_neighbor(size_t layer, FVectorNode *other)
{
  DBUG_ASSERT(neighbors[layer].num < ctx->max_neighbors(layer));
  neighbors[layer].links[neighbors[layer].num++]= other;
}

size_t FVectorNode::tref_len() const { return ctx->tref_len; }
size_t FVectorNode::gref_len() const { return ctx->gref_len; }
uchar *FVectorNode::gref() const { return (uchar*)(this+1); }
uchar *FVectorNode::tref() const { return gref() + gref_len(); }

uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{
  *key_len= elem->gref_len();
  return elem->gref();
}

/* one visited node during the search. caches the distance to target */
struct Visited : public Sql_alloc
{
  FVectorNode *node;
  const float distance_to_target;
  Visited(FVectorNode *n, float d) : node(n), distance_to_target(d) {}
  static int cmp(void *, const Visited* a, const Visited *b)
  {
    return a->distance_to_target < b->distance_to_target ? -1 :
           a->distance_to_target > b->distance_to_target ?  1 : 0;
  }
};

/*
  a factory to create Visited and keep track of already seen nodes

  note that PatternedSimdBloomFilter works in blocks of 8 elements,
  so on insert they're accumulated in nodes[], on search the caller
  provides 8 addresses at once. we record 0x0 as "seen" so that
  the caller could pad the input with nullptr's
*/
class VisitedSet
{
  MEM_ROOT *root;
  const FVector *target;
  PatternedSimdBloomFilter<FVectorNode> map;
  const FVectorNode *nodes[8]= {0,0,0,0,0,0,0,0};
  size_t idx= 1;                        // to record 0 in the filter
public:
  uint count= 0;
  VisitedSet(MEM_ROOT *root, const FVector *target, uint size) :
    root(root), target(target), map(size, 0.01f) {}
  Visited *create(FVectorNode *node)
  {
    auto *v= new (root) Visited(node, node->distance_to(target));
    insert(node);
    count++;
    return v;
  }
  void insert(const FVectorNode *n)
  {
    nodes[idx++]= n;
    if (idx == 8) flush();
  }
  void flush()
  {
    if (idx) map.Insert(nodes);
    idx=0;
  }
  uint8_t seen(FVectorNode **nodes) { return map.Query(nodes); }
};

/*
  selects best neighbors from the list of candidates plus one extra candidate

  one extra candidate is specified separately to avoid appending it to
  the Neighborhood candidates, which might be already at its max size.
*/
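/*
  This is essentially the neighbor-selection heuristic from the HNSW paper,
  relaxed by the alpha factor: a candidate is accepted only if it is not
  closer (within alpha) to an already accepted neighbor than to the target
  itself. Rejected candidates are kept aside and used to fill the remaining
  slots if fewer than max_neighbor_connections get accepted.
*/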
static int select_neighbors(MHNSW_Context *ctx, TABLE *graph, size_t layer,
                            FVectorNode &target,
                            const Neighborhood &candidates,
                            FVectorNode *extra_candidate,
                            size_t max_neighbor_connections)
{
  Queue<Visited> pq;                    // working queue
  if (pq.init(10000, false, Visited::cmp))
    return my_errno= HA_ERR_OUT_OF_MEM;

  MEM_ROOT * const root= graph->in_use->mem_root;
  auto discarded= (Visited**)my_safe_alloca(sizeof(Visited**)*max_neighbor_connections);
  size_t discarded_num= 0;
  Neighborhood &neighbors= target.neighbors[layer];

  for (size_t i=0; i < candidates.num; i++)
  {
    FVectorNode *node= candidates.links[i];
    if (int err= node->load(graph))
      return err;
    pq.push(new (root) Visited(node, node->distance_to(target.vec)));
  }
  if (extra_candidate)
    pq.push(new (root) Visited(extra_candidate,
                               extra_candidate->distance_to(target.vec)));

  DBUG_ASSERT(pq.elements());
  neighbors.num= 0;

  while (pq.elements() && neighbors.num < max_neighbor_connections)
  {
    Visited *vec= pq.pop();
    FVectorNode * const node= vec->node;
    const float target_dista= vec->distance_to_target / alpha;
    bool discard= false;
    for (size_t i=0; i < neighbors.num; i++)
      if ((discard= node->distance_to(neighbors.links[i]->vec) < target_dista))
        break;
    if (!discard)
      target.push_neighbor(layer, node);
    else if (discarded_num + neighbors.num < max_neighbor_connections)
      discarded[discarded_num++]= vec;
  }

  for (size_t i=0; i < discarded_num && neighbors.num < max_neighbor_connections; i++)
    target.push_neighbor(layer, discarded[i]->node);

  my_safe_afree(discarded, sizeof(Visited**)*max_neighbor_connections);
  return 0;
}

int FVectorNode::save(TABLE *graph)
{
  DBUG_ASSERT(vec);
  DBUG_ASSERT(neighbors);

  restore_record(graph, s->default_values);
  graph->field[FIELD_LAYER]->store(max_layer, false);
  if (deleted)
    graph->field[FIELD_TREF]->set_null();
  else
  {
    graph->field[FIELD_TREF]->set_notnull();
    graph->field[FIELD_TREF]->store_binary(tref(), tref_len());
  }
  graph->field[FIELD_VEC]->store_binary(vec->data(),
                                        FVector::data_size(ctx->vec_len));

  size_t total_size= 0;
  for (size_t i=0; i <= max_layer; i++)
    total_size+= 1 + gref_len() * neighbors[i].num;

  uchar *neighbor_blob= static_cast<uchar*>(my_safe_alloca(total_size));
  uchar *ptr= neighbor_blob;
  for (size_t i= 0; i <= max_layer; i++)
  {
    *ptr++= (uchar)(neighbors[i].num);
    for (size_t j= 0; j < neighbors[i].num; j++, ptr+= gref_len())
      memcpy(ptr, neighbors[i].links[j]->gref(), gref_len());
  }
  graph->field[FIELD_NEIGHBORS]->store_binary(neighbor_blob, total_size);

  int err;
  if (stored)
  {
    if (!(err= graph->file->ha_rnd_pos(graph->record[1], gref())))
    {
      err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
      if (err == HA_ERR_RECORD_IS_THE_SAME)
        err= 0;
    }
  }
  else
  {
    err= graph->file->ha_write_row(graph->record[0]);
    graph->file->position(graph->record[0]);
    memcpy(gref(), graph->file->ref, gref_len());
    stored= true;
    ctx->cache_node(this);
  }
  my_safe_afree(neighbor_blob, total_size);
  return err;
}
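/*
  After a new node got its own neighbor lists, add the corresponding
  back-links: every chosen neighbor gets a link to the new node, either by
  simply appending it (if the neighbor still has a free slot) or by
  re-running select_neighbors() over its current links plus the new node,
  and the updated neighbor rows are written back to the graph table.
*/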
static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
                                          size_t layer, FVectorNode *node)
{
  const uint max_neighbors= ctx->max_neighbors(layer);
  // it seems that one could update nodes in the gref order
  // to avoid InnoDB deadlocks, but it produces no noticeable effect
  for (size_t i=0; i < node->neighbors[layer].num; i++)
  {
    FVectorNode *neigh= node->neighbors[layer].links[i];
    Neighborhood &neighneighbors= neigh->neighbors[layer];
    if (neighneighbors.num < max_neighbors)
      neigh->push_neighbor(layer, node);
    else if (int err= select_neighbors(ctx, graph, layer, *neigh,
                                       neighneighbors, node, max_neighbors))
      return err;
    if (int err= neigh->save(graph))
      return err;
  }
  return 0;
}
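/*
  Greedy best-first search on one layer. "candidates" is a min-heap of nodes
  still to expand, "best" a bounded max-heap of the ef closest nodes found so
  far; the search stops when the nearest unexpanded candidate is already
  worse than the worst node in "best" (widened by the "generosity" factor,
  which trades a bit of extra work for better recall). Visited nodes are
  tracked in a Bloom filter whose size is estimated from ef and the adaptive
  ctx->ef_power exponent, updated below when the estimate turns out too small.
*/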
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
                        Neighborhood *start_nodes, uint result_size,
                        size_t layer, Neighborhood *result, bool construction)
{
  DBUG_ASSERT(start_nodes->num > 0);
  result->num= 0;

  MEM_ROOT * const root= graph->in_use->mem_root;
  Queue<Visited> candidates, best;
  bool skip_deleted;
  uint ef= result_size;

  if (construction)
  {
    skip_deleted= false;
    if (ef > 1)
      ef= std::max(ef_construction, ef);
  }
  else
  {
    skip_deleted= layer == 0;
    if (ef > 1 || layer == 0)
      ef= std::max(THDVAR(graph->in_use, min_limit), ef);
  }

  // WARNING! heuristic here
  const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
  const uint est_size= static_cast<uint>(est_heuristic * std::pow(ef, ctx->ef_power));
  VisitedSet visited(root, target, est_size);

  candidates.init(10000, false, Visited::cmp);
  best.init(ef, true, Visited::cmp);

  DBUG_ASSERT(start_nodes->num <= result_size);
  for (size_t i=0; i < start_nodes->num; i++)
  {
    Visited *v= visited.create(start_nodes->links[i]);
    candidates.push(v);
    if (skip_deleted && v->node->deleted)
      continue;
    best.push(v);
  }

  float furthest_best= best.is_empty() ? FLT_MAX
                       : best.top()->distance_to_target * generosity;
  while (candidates.elements())
  {
    const Visited &cur= *candidates.pop();
    if (cur.distance_to_target > furthest_best && best.is_full())
      break;            // All possible candidates are worse than what we have
    visited.flush();

    Neighborhood &neighbors= cur.node->neighbors[layer];
    FVectorNode **links= neighbors.links, **end= links + neighbors.num;
    for (; links < end; links+= 8)
    {
      uint8_t res= visited.seen(links);
      if (res == 0xff)
        continue;

      for (size_t i= 0; i < 8; i++)
      {
        if (res & (1 << i))
          continue;
        if (int err= links[i]->load(graph))
          return err;
        Visited *v= visited.create(links[i]);
        if (!best.is_full())
        {
          candidates.push(v);
          if (skip_deleted && v->node->deleted)
            continue;
          best.push(v);
          furthest_best= best.top()->distance_to_target * generosity;
        }
        else if (v->distance_to_target < furthest_best)
        {
          candidates.safe_push(v);
          if (skip_deleted && v->node->deleted)
            continue;
          if (v->distance_to_target < best.top()->distance_to_target)
          {
            best.replace_top(v);
            furthest_best= best.top()->distance_to_target * generosity;
          }
        }
      }
    }
  }
  if (ef > 1 && visited.count*2 > est_size)
  {
    double ef_power= std::log(visited.count*2/est_heuristic) / std::log(ef);
    set_if_bigger(ctx->ef_power, ef_power);     // not atomic, but it's ok
  }

  while (best.elements() > result_size)
    best.pop();
  result->num= best.elements();
  for (FVectorNode **links= result->links + result->num; best.elements();)
    *--links= best.pop()->node;
  return 0;
}

static int bad_value_on_insert(Field *f)
{
  my_error(ER_TRUNCATED_WRONG_VALUE_FOR_FIELD, MYF(0), "vector", "...",
           f->table->s->db.str, f->table->s->table_name.str, f->field_name.str,
           f->table->in_use->get_stmt_da()->current_row_for_warning());
  return my_errno= HA_ERR_GENERIC;
}

int mhnsw_insert(TABLE *table, KEY *keyinfo)
{
  THD *thd= table->in_use;
  TABLE *graph= table->hlindex;
  MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
  Field *vec_field= keyinfo->key_part->field;
  String buf, *res= vec_field->val_str(&buf);
  MHNSW_Context *ctx;

  /* metadata are checked on open */
  DBUG_ASSERT(graph);
  DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
  DBUG_ASSERT(keyinfo->usable_key_parts == 1);
  DBUG_ASSERT(vec_field->binary());
  DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT);
  DBUG_ASSERT(res);                     // ER_INDEX_CANNOT_HAVE_NULL
  DBUG_ASSERT(table->file->ref_length <= graph->field[FIELD_TREF]->field_length);

  // XXX returning an error here will rollback the insert in InnoDB
  // but in MyISAM the row will stay inserted, making the index out of sync:
  // invalid vector values are present in the table but cannot be found
  // via an index. The easiest way to fix it is with a VECTOR(N) type
  if (res->length() == 0 || res->length() % 4)
    return bad_value_on_insert(vec_field);

  table->file->position(table->record[0]);

  int err= MHNSW_Context::acquire(&ctx, table, true);
  SCOPE_EXIT([ctx, table](){ ctx->release(table); });
  if (err)
  {
    if (err != HA_ERR_END_OF_FILE)
      return err;

    // First insert!
    ctx->set_lengths(res->length());
    FVectorNode *target= new (ctx->alloc_node())
                         FVectorNode(ctx, table->file->ref, 0, res->ptr());
    if (!((err= target->save(graph))))
      ctx->start= target;
    return err;
  }

  if (ctx->byte_len != res->length())
    return bad_value_on_insert(vec_field);

  MEM_ROOT_SAVEPOINT memroot_sv;
  root_make_savepoint(thd->mem_root, &memroot_sv);
  SCOPE_EXIT([&memroot_sv](){ root_free_to_savepoint(&memroot_sv); });

  const size_t max_found= ctx->max_neighbors(0);
  Neighborhood candidates, start_nodes;
  candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
  start_nodes.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
  start_nodes.links[start_nodes.num++]= ctx->start;
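  /*
    Pick the top layer for the new node from the standard HNSW geometric
    distribution: floor(-ln(U) / ln(M)) for a uniform U in (0,1), capped at
    one level above the current top layer, so the entry point can move up
    by at most one layer per insert.
  */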
  const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
  double log= -std::log(my_rnd(&thd->rand)) * NORMALIZATION_FACTOR;
  const uint8_t max_layer= start_nodes.links[0]->max_layer;
  uint8_t target_layer= std::min(static_cast<int>(std::floor(log)),
                                 max_layer + 1);
  int cur_layer;

  FVectorNode *target= new (ctx->alloc_node())
                       FVectorNode(ctx, table->file->ref, target_layer,
                                   res->ptr());

  if (int err= graph->file->ha_rnd_init(0))
    return err;
  SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });

  for (cur_layer= max_layer; cur_layer > target_layer; cur_layer--)
  {
    if (int err= search_layer(ctx, graph, target->vec, &start_nodes, 1,
                              cur_layer, &candidates, false))
      return err;
    std::swap(start_nodes, candidates);
  }

  for (; cur_layer >= 0; cur_layer--)
  {
    uint max_neighbors= ctx->max_neighbors(cur_layer);
    if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
                              max_neighbors, cur_layer, &candidates, true))
      return err;

    if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
                                  0, max_neighbors))
      return err;
    std::swap(start_nodes, candidates);
  }

  if (int err= target->save(graph))
    return err;

  if (target_layer > max_layer)
    ctx->start= target;

  for (cur_layer= target_layer; cur_layer >= 0; cur_layer--)
  {
    if (int err= update_second_degree_neighbors(ctx, graph, cur_layer, target))
      return err;
  }

  dbug_tmp_restore_column_map(&table->read_set, old_map);
  return 0;
}

int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
{
  THD *thd= table->in_use;
  TABLE *graph= table->hlindex;
  auto *fun= static_cast<Item_func_vec_distance*>(dist->real_item());
  DBUG_ASSERT(fun);
  String buf, *res= fun->get_const_arg()->val_str(&buf);
  MHNSW_Context *ctx;

  if (int err= table->file->ha_rnd_init(0))
    return err;

  int err= MHNSW_Context::acquire(&ctx, table, false);
  SCOPE_EXIT([ctx, table](){ ctx->release(table); });
  if (err)
    return err;

  Neighborhood candidates, start_nodes;
  candidates.init(thd->alloc<FVectorNode*>(limit + 7), limit);
  start_nodes.init(thd->alloc<FVectorNode*>(limit + 7), limit);

  // one could put all max_layer nodes in start_nodes
  // but it has no effect on the recall or speed
  start_nodes.links[start_nodes.num++]= ctx->start;

  /*
    if the query vector is NULL or invalid, VEC_DISTANCE will return
    NULL, so the result is basically unsorted, we can return rows
    in any order. Let's use some hardcoded value here
  */
  if (!res || ctx->byte_len != res->length())
  {
    res= &buf;
    buf.alloc(ctx->byte_len);
    buf.length(ctx->byte_len);
    for (size_t i=0; i < ctx->vec_len; i++)
      ((float*)buf.ptr())[i]= i == 0;
  }

  const longlong max_layer= start_nodes.links[0]->max_layer;
  auto target= FVector::create(ctx->metric,
                               thd->alloc(FVector::alloc_size(ctx->vec_len)),
                               res->ptr(), res->length());

  if (int err= graph->file->ha_rnd_init(0))
    return err;
  SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });

  for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
  {
    if (int err= search_layer(ctx, graph, target, &start_nodes, 1, cur_layer,
                              &candidates, false))
      return err;
    std::swap(start_nodes, candidates);
  }

  if (int err= search_layer(ctx, graph, target, &start_nodes,
                            static_cast<uint>(limit), 0, &candidates, false))
    return err;

  if (limit > candidates.num)
    limit= candidates.num;
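  /*
    Pack the result into graph->context: a ulonglong row count followed by
    the trefs of the found rows, stored back to front (the nearest row goes
    last). mhnsw_read_next() below walks the array from the end while
    decrementing the count, so rows come back nearest-first.
  */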
  size_t context_size= limit * ctx->tref_len + sizeof(ulonglong);
  char *context= thd->alloc<char>(context_size);
  graph->context= context;

  *(ulonglong*)context= limit;
  context+= context_size;

  for (size_t i=0; limit--; i++)
  {
    context-= ctx->tref_len;
    memcpy(context, candidates.links[i]->tref(), ctx->tref_len);
  }
  DBUG_ASSERT(context - sizeof(ulonglong) == graph->context);
  return mhnsw_read_next(table);
}

int mhnsw_read_next(TABLE *table)
{
  uchar *ref= (uchar*)(table->hlindex->context);
  if (ulonglong *limit= (ulonglong*)ref)
  {
    ref+= sizeof(ulonglong) + (--*limit) * table->file->ref_length;
    return table->file->ha_rnd_pos(table->record[0], ref);
  }
  return my_errno= HA_ERR_END_OF_FILE;
}

void mhnsw_free(TABLE_SHARE *share)
{
  TABLE_SHARE *graph_share= share->hlindex;
  if (!graph_share->hlindex_data)
    return;

  static_cast<MHNSW_Context*>(graph_share->hlindex_data)->~MHNSW_Context();
  graph_share->hlindex_data= 0;
}

int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo)
{
  TABLE *graph= table->hlindex;
  handler *h= table->file;
  MHNSW_Context *ctx;
  bool use_ctx= !MHNSW_Context::acquire(&ctx, table, true);

  /* metadata are checked on open */
  DBUG_ASSERT(graph);
  DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
  DBUG_ASSERT(keyinfo->usable_key_parts == 1);
  DBUG_ASSERT(h->ref_length <= graph->field[FIELD_TREF]->field_length);

  // target record:
  h->position(rec);
  graph->field[FIELD_TREF]->set_notnull();
  graph->field[FIELD_TREF]->store_binary(h->ref, h->ref_length);

  uchar *key= (uchar*)alloca(graph->key_info[IDX_TREF].key_length);
  key_copy(key, graph->record[0], &graph->key_info[IDX_TREF],
           graph->key_info[IDX_TREF].key_length);

  if (int err= graph->file->ha_index_read_idx_map(graph->record[1], IDX_TREF,
                                      key, HA_WHOLE_KEY, HA_READ_KEY_EXACT))
    return err;

  restore_record(graph, record[1]);
  graph->field[FIELD_TREF]->set_null();
  if (int err= graph->file->ha_update_row(graph->record[1], graph->record[0]))
    return err;

  if (use_ctx)
  {
    graph->file->position(graph->record[0]);
    FVectorNode *node= ctx->get_node(graph->file->ref);
    node->deleted= true;
    ctx->release(table);
  }

  return 0;
}

int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate)
{
  TABLE *graph= table->hlindex;

  /* metadata are checked on open */
  DBUG_ASSERT(graph);
  DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
  DBUG_ASSERT(keyinfo->usable_key_parts == 1);

  if (int err= truncate ? graph->file->truncate()
                        : graph->file->delete_all_rows())
    return err;

  MHNSW_Context *ctx;
  if (!MHNSW_Context::acquire(&ctx, table, true))
  {
    ctx->reset(table->s);
    ctx->release(table);
  }
  return 0;
}

const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
{
  const char templ[]= "CREATE TABLE i ( "
                      " layer tinyint not null, "
                      " tref varbinary(%u), "
                      " vec blob not null, "
                      " neighbors blob not null, "
                      " unique (tref), "
                      " key (layer)) ";
  size_t len= sizeof(templ) + 32;
  char *s= thd->alloc<char>(len);
  len= my_snprintf(s, len, templ, ref_length);
  return {s, len};
}

bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist)
{
  if (keyinfo->option_struct->metric == EUCLIDEAN)
    return dynamic_cast<const Item_func_vec_distance_euclidean*>(dist) != NULL;
  return dynamic_cast<const Item_func_vec_distance_cosine*>(dist) != NULL;
}

/*
  Declare the plugin and index options
*/

ha_create_table_option mhnsw_index_options[]=
{
  HA_IOPTION_SYSVAR("max_edges_per_node", M, max_edges_per_node),
  HA_IOPTION_SYSVAR("distance_function", metric, distance_function),
  HA_IOPTION_END
};

st_plugin_int *mhnsw_plugin;

static int mhnsw_init(void *p)
{
  mhnsw_plugin= (st_plugin_int *)p;
  mhnsw_plugin->data= &MHNSW_Trx::tp;
  if (setup_transaction_participant(mhnsw_plugin))
    return 1;

  return resolve_sysvar_table_options(mhnsw_index_options);
}

static int mhnsw_deinit(void *)
{
  free_sysvar_table_options(mhnsw_index_options);
  return 0;
}

static struct st_mysql_storage_engine mhnsw_daemon=
{ MYSQL_DAEMON_INTERFACE_VERSION };

static struct st_mysql_sys_var *mhnsw_sys_vars[]=
{
  MYSQL_SYSVAR(cache_size),
  MYSQL_SYSVAR(max_edges_per_node),
  MYSQL_SYSVAR(distance_function),
  MYSQL_SYSVAR(min_limit),
  NULL
};

maria_declare_plugin(mhnsw)
{
  MYSQL_DAEMON_PLUGIN,
  &mhnsw_daemon,
  "mhnsw",
  "MariaDB plc",
  "A plugin for mhnsw vector index algorithm",
  PLUGIN_LICENSE_GPL,
  mhnsw_init,
  mhnsw_deinit,
  0x0100,
  NULL,
  mhnsw_sys_vars,
  "1.0",
  MariaDB_PLUGIN_MATURITY_STABLE
}
maria_declare_plugin_end;
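/*
  Illustrative usage only - the exact SQL-level syntax (column type, index
  option names, distance function name) is defined outside of this file and
  may differ; this sketch just shows how the pieces above are meant to fit
  together:

    CREATE TABLE t1 (id INT PRIMARY KEY, v BLOB NOT NULL,
                     VECTOR INDEX (v) MAX_EDGES_PER_NODE=6
                                      DISTANCE_FUNCTION=euclidean);
    SELECT id FROM t1 ORDER BY VEC_DISTANCE(v, @query_vec) LIMIT 10;

  INSERTs go through mhnsw_insert(), and the ORDER BY ... LIMIT query is
  served by mhnsw_read_first()/mhnsw_read_next() above.
*/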