From 049d839350f1ccb1c4b0c5ce04b2da3db71109d6 Mon Sep 17 00:00:00 2001 From: Sergei Golubchik Date: Wed, 17 Jul 2024 17:16:28 +0200 Subject: [PATCH] mhnsw: inter-statement shared cache * preserve the graph in memory between statements * keep it in a TABLE_SHARE, available for concurrent searches * nodes are generally read-only, walking the graph doesn't change them * distance to target is cached, calculated only once * SIMD-optimized bloom filter detects visited nodes * nodes are stored in an array, not List, to better utilize bloom filter * auto-adjusting heuristic to estimate the number of visited nodes (to configure the bloom filter) * many threads can concurrently walk the graph. MEM_ROOT and Hash_set are protected with a mutex, but walking doesn't need them * up to 8 threads can concurrently load nodes into the cache, nodes are partitioned into 8 mutexes (8 is chosen arbitrarily, might need tuning) * concurrent editing is not supported though * this is fine for MyISAM, TL_WRITE protects the TABLE_SHARE and the graph (note that TL_WRITE_CONCURRENT_INSERT is not allowed, because an INSERT into the main table means multiple UPDATEs in the graph) * InnoDB uses secondary transaction-level caches linked in a list in thd->ha_data via a fake handlerton * on rollback the secondary cache is discarded, on commit nodes from the secondary cache are invalidated in the shared cache while it is exclusively locked * on savepoint rollback both caches are flushed. 
this can be improved in the future with a row visibility callback * graph size is controlled by @@mhnsw_cache_size, the cache is flushed when it reaches the threshold --- config.h.cmake | 1 + configure.cmake | 1 + include/my_sys.h | 6 + mysql-test/main/mysqld--help.result | 3 + mysql-test/main/vector_innodb.result | 92 ++ mysql-test/main/vector_innodb.test | 73 ++ .../r/sysvars_server_embedded,32bit.rdiff | 11 +- .../sys_vars/r/sysvars_server_embedded.result | 10 + .../r/sysvars_server_notembedded,32bit.rdiff | 13 +- .../r/sysvars_server_notembedded.result | 10 + mysys/my_alloc.c | 1 + sql/bloom_filters.h | 191 +++ sql/handler.h | 1 + sql/sql_base.cc | 4 + sql/sys_vars.cc | 5 + sql/table.cc | 5 + sql/table.h | 6 +- sql/vector_mhnsw.cc | 1070 ++++++++++++----- sql/vector_mhnsw.h | 3 + 19 files changed, 1179 insertions(+), 327 deletions(-) create mode 100644 mysql-test/main/vector_innodb.result create mode 100644 mysql-test/main/vector_innodb.test create mode 100644 sql/bloom_filters.h diff --git a/config.h.cmake b/config.h.cmake index 8b9fe5eee68..f99dfe85480 100644 --- a/config.h.cmake +++ b/config.h.cmake @@ -43,6 +43,7 @@ #cmakedefine HAVE_IA64INTRIN_H 1 #cmakedefine HAVE_IEEEFP_H 1 #cmakedefine HAVE_INTTYPES_H 1 +#cmakedefine HAVE_IMMINTRIN_H 1 #cmakedefine HAVE_KQUEUE 1 #cmakedefine HAVE_LIMITS_H 1 #cmakedefine HAVE_LINK_H 1 diff --git a/configure.cmake b/configure.cmake index 7f90e286155..e7d79880fe5 100644 --- a/configure.cmake +++ b/configure.cmake @@ -187,6 +187,7 @@ CHECK_INCLUDE_FILES (fpu_control.h HAVE_FPU_CONTROL_H) CHECK_INCLUDE_FILES (grp.h HAVE_GRP_H) CHECK_INCLUDE_FILES (ieeefp.h HAVE_IEEEFP_H) CHECK_INCLUDE_FILES (inttypes.h HAVE_INTTYPES_H) +CHECK_INCLUDE_FILES (immintrin.h HAVE_IMMINTRIN_H) CHECK_INCLUDE_FILES (langinfo.h HAVE_LANGINFO_H) CHECK_INCLUDE_FILES (link.h HAVE_LINK_H) CHECK_INCLUDE_FILES (linux/unistd.h HAVE_LINUX_UNISTD_H) diff --git a/include/my_sys.h b/include/my_sys.h index a9bde7422bb..cf9df3e7fa2 100644 --- a/include/my_sys.h 
+++ b/include/my_sys.h @@ -947,6 +947,12 @@ extern LEX_STRING lex_string_casedn_root(MEM_ROOT *root, CHARSET_INFO *cs, const char *str, size_t length); +static inline size_t root_size(MEM_ROOT *root) +{ + size_t k = root->block_num >> 2; + return k * (k + 1) * 2 * root->block_size; +} + extern my_bool my_compress(uchar *, size_t *, size_t *); extern my_bool my_uncompress(uchar *, size_t , size_t *); extern uchar *my_compress_alloc(const uchar *packet, size_t *len, diff --git a/mysql-test/main/mysqld--help.result b/mysql-test/main/mysqld--help.result index 8b1110b50b0..98b91beef67 100644 --- a/mysql-test/main/mysqld--help.result +++ b/mysql-test/main/mysqld--help.result @@ -708,6 +708,8 @@ The following specify which files/extra groups are read (specified before remain Unused. Deprecated, will be removed in a future release. --metadata-locks-hash-instances=# Unused. Deprecated, will be removed in a future release. + --mhnsw-cache-size=# + Size of the cache for the MHNSW vector index --mhnsw-max-edges-per-node=# Larger values means slower INSERT, larger index size and higher memory consumption, but better search results @@ -1830,6 +1832,7 @@ max-write-lock-count 18446744073709551615 memlock FALSE metadata-locks-cache-size 1024 metadata-locks-hash-instances 8 +mhnsw-cache-size 16777216 mhnsw-max-edges-per-node 6 mhnsw-min-limit 20 min-examined-row-limit 0 diff --git a/mysql-test/main/vector_innodb.result b/mysql-test/main/vector_innodb.result new file mode 100644 index 00000000000..4c1a718fbae --- /dev/null +++ b/mysql-test/main/vector_innodb.result @@ -0,0 +1,92 @@ +create table t1 (id int auto_increment primary key, v blob not null, vector index (v)) engine=innodb; +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `v` blob NOT NULL, + PRIMARY KEY (`id`), + VECTOR KEY `v` (`v`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_uca1400_ai_ci +insert t1 (v) values 
+(x'106d263fdf68ba3eb08d533f97d46e3fd1e1ec3edc4c123f984c563f621a233f'), +(x'd55bee3c56eb9e3e84e3093f838dce3eb7cd653fe32d7d3f12de133c5715d23e'), +(x'fcd5553f3822443f5dae413f2593493f7777363f5f7f113ebf12373d4d145a3f'), +(x'7493093fd9a27d3e9b13783f8c66653f0bd7d23e50db983d251b013f1dba133f'), +(x'2e30373fae331a3eba94153ee32bce3e3311b33d5bc75d3f6c25653eb769113f'), +(x'381d5f3f2781de3e4f011f3f9353483f9bb37e3edd622d3eabecb63ec246953e'), +(x'4ee5dc3e214b103f0e7e583f5f36473e79d7823ea872ec3e3ab2913d1b84433f'), +(x'8826243f7d20f03e5135593f83ba653e44572d3fa87e8e3e943e0e3f649a293f'), +(x'3859ac3e7d21823ed3f5753fc79c143e61d39c3cee39ba3eb0b0133e815c173f'), +(x'cff0d93c32941e3f64b22a3f1e4f083f4ea2563fbff4a63e12a4703f6c824b3f'); +start transaction; +insert t1 values +(30, x'f8e2413ed4ff773fef8b893eba487b3febee3f3f9e6f693f5961fd3ee479303d'); +savepoint foo; +insert t1 values +(31, x'6129683f90fe1f3e1437bc3ed8c8f63dd141033f21e3a93e54346c3f8c4e043f'), +(32, x'1ec8b83d398c4d3f2efb463f23947a3fa1a5093fdde6303e5580413f51569b3e'); +rollback to savepoint foo; +insert t1 values +(33, x'86d1003d4262033f8086713ffc4a633e317e933c4dce013d9c4d573fca83b93e'); +commit; +start transaction; +insert t1 values +(40, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f'), +(41, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f'); +rollback; +select id,vec_distance(v, x'c923e33dc0da313fe7c7983e526b3d3fde63963e6eaf3a3f27fa133fe27a583f') d from t1 order by d limit 5; +id d +10 0.8856208347761952 +1 0.9381363209273885 +30 1.0162643974895857 +7 1.026397313888122 +5 1.0308161006949719 +select id,vec_distance(v, x'754b5f3ea2312b3fc169f43e4604883e1d20173e8dd7443f421b703fb11e0d3e') d from t1 order by d limit 5; +id d +33 0.9477554826856 +30 1.111405427702547 +1 1.1154613877616022 +10 1.118630286292343 +8 1.1405733350751739 +create table t2 (id int auto_increment primary key, v blob not null, vector index (v)) engine=innodb; +insert t2 (v) values 
+(x'45cf153f830a313f7a0a113fb1ff533f47a1533fcf9e6e3f'), +(x'4b311d3fdd82423f35ba7d3fa041223dfd7db03e72d5833e'), +(x'f0d4123f6fc1833ea30a483fd9649d3cb94d733f4574a63d'), +(x'7ff8a53bf68e4a3e66e3563f214dea3e63372f3ec24d513f'), +(x'4709683f0d44473f8a045f3f40f3693df7f1303fdb98b73e'), +(x'09de2b3f5db80d3fb4405f3f64aadc3ecfa6183f823c733f'), +(x'a93a143f7f71e33d0cde5c3ff106373fd6f6233fc1f4fc3e'), +(x'11236e3de44a0d3f8241023d44d8383f2f70733f44d65c3f'), +(x'b5e47c3f35d3413fad8a533d5945133f66dbf33d92c6103f'); +start transaction; +insert t1 values +(50, x'acae183f56ddc43e5093983d280df53e6fa2093f79c01a3eb1591f3f423a0e3d'), +(51, x'6285303f42ef6e3f355e313f3e96a53e70959b3edd720b3ec07f733e5bc8603f'); +insert t2 values +(20, x'58dc7d3fc9feaa3e19e26b3f31820c3f93070b3fc4e36e3f'), +(21, x'35e05d3f18e8513fb81a3d3f8acf7d3e794a1d3c72f9613f'); +commit; +select id,vec_distance(v, x'1f4d053f7056493f937da03dd8c97a3f220cbb3c926c1c3facca213ec0618a3e') d from t1 order by d limit 5; +id d +6 0.9309383181777582 +5 0.9706304662574956 +30 0.98144492002831 +50 1.079862635421575 +51 1.2403734530917931 +select id,vec_distance(v, x'f618663f256be73e62cd453f8bcdbf3e16ae503c3858313f') d from t2 order by d limit 5; +id d +21 0.43559180321379337 +20 0.6435053022072372 +6 0.6942000623336242 +2 0.7971622099055623 +9 0.8298589136476077 +drop table t1, t2; +# +# MDEV-34989 After selecting from empty table with vector key the next insert hangs +# +create table t (v blob not null, vector key(v)) engine=InnoDB; +select vec_distance(v, x'B047263C9F87233fcfd27e3eae493e3f0329f43e') as e from t order by e limit 1; +e +insert into t values (x'B047263C9F87233fcfd27e3eae493e3f0329f43e'); +drop table t; diff --git a/mysql-test/main/vector_innodb.test b/mysql-test/main/vector_innodb.test new file mode 100644 index 00000000000..b743d64f71e --- /dev/null +++ b/mysql-test/main/vector_innodb.test @@ -0,0 +1,73 @@ +source include/have_innodb.inc; + +create table t1 (id int auto_increment primary key, v blob not null, vector 
index (v)) engine=innodb; +show create table t1; +# print unpack("H*",pack("f*",map{rand}1..8)) +insert t1 (v) values + (x'106d263fdf68ba3eb08d533f97d46e3fd1e1ec3edc4c123f984c563f621a233f'), + (x'd55bee3c56eb9e3e84e3093f838dce3eb7cd653fe32d7d3f12de133c5715d23e'), + (x'fcd5553f3822443f5dae413f2593493f7777363f5f7f113ebf12373d4d145a3f'), + (x'7493093fd9a27d3e9b13783f8c66653f0bd7d23e50db983d251b013f1dba133f'), + (x'2e30373fae331a3eba94153ee32bce3e3311b33d5bc75d3f6c25653eb769113f'), + (x'381d5f3f2781de3e4f011f3f9353483f9bb37e3edd622d3eabecb63ec246953e'), + (x'4ee5dc3e214b103f0e7e583f5f36473e79d7823ea872ec3e3ab2913d1b84433f'), + (x'8826243f7d20f03e5135593f83ba653e44572d3fa87e8e3e943e0e3f649a293f'), + (x'3859ac3e7d21823ed3f5753fc79c143e61d39c3cee39ba3eb0b0133e815c173f'), + (x'cff0d93c32941e3f64b22a3f1e4f083f4ea2563fbff4a63e12a4703f6c824b3f'); + +### savepoints and rollbacks: +start transaction; +insert t1 values + (30, x'f8e2413ed4ff773fef8b893eba487b3febee3f3f9e6f693f5961fd3ee479303d'); +savepoint foo; +insert t1 values + (31, x'6129683f90fe1f3e1437bc3ed8c8f63dd141033f21e3a93e54346c3f8c4e043f'), + (32, x'1ec8b83d398c4d3f2efb463f23947a3fa1a5093fdde6303e5580413f51569b3e'); +rollback to savepoint foo; +insert t1 values + (33, x'86d1003d4262033f8086713ffc4a633e317e933c4dce013d9c4d573fca83b93e'); +commit; +start transaction; +insert t1 values + (40, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f'), + (41, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f'); +rollback; + +select id,vec_distance(v, x'c923e33dc0da313fe7c7983e526b3d3fde63963e6eaf3a3f27fa133fe27a583f') d from t1 order by d limit 5; +select id,vec_distance(v, x'754b5f3ea2312b3fc169f43e4604883e1d20173e8dd7443f421b703fb11e0d3e') d from t1 order by d limit 5; + +### two indexes in one transaction: +create table t2 (id int auto_increment primary key, v blob not null, vector index (v)) engine=innodb; + +insert t2 (v) values + (x'45cf153f830a313f7a0a113fb1ff533f47a1533fcf9e6e3f'), + 
(x'4b311d3fdd82423f35ba7d3fa041223dfd7db03e72d5833e'), + (x'f0d4123f6fc1833ea30a483fd9649d3cb94d733f4574a63d'), + (x'7ff8a53bf68e4a3e66e3563f214dea3e63372f3ec24d513f'), + (x'4709683f0d44473f8a045f3f40f3693df7f1303fdb98b73e'), + (x'09de2b3f5db80d3fb4405f3f64aadc3ecfa6183f823c733f'), + (x'a93a143f7f71e33d0cde5c3ff106373fd6f6233fc1f4fc3e'), + (x'11236e3de44a0d3f8241023d44d8383f2f70733f44d65c3f'), + (x'b5e47c3f35d3413fad8a533d5945133f66dbf33d92c6103f'); + +start transaction; +insert t1 values + (50, x'acae183f56ddc43e5093983d280df53e6fa2093f79c01a3eb1591f3f423a0e3d'), + (51, x'6285303f42ef6e3f355e313f3e96a53e70959b3edd720b3ec07f733e5bc8603f'); +insert t2 values + (20, x'58dc7d3fc9feaa3e19e26b3f31820c3f93070b3fc4e36e3f'), + (21, x'35e05d3f18e8513fb81a3d3f8acf7d3e794a1d3c72f9613f'); +commit; + +select id,vec_distance(v, x'1f4d053f7056493f937da03dd8c97a3f220cbb3c926c1c3facca213ec0618a3e') d from t1 order by d limit 5; +select id,vec_distance(v, x'f618663f256be73e62cd453f8bcdbf3e16ae503c3858313f') d from t2 order by d limit 5; + +drop table t1, t2; + +--echo # +--echo # MDEV-34989 After selecting from empty table with vector key the next insert hangs +--echo # +create table t (v blob not null, vector key(v)) engine=InnoDB; +select vec_distance(v, x'B047263C9F87233fcfd27e3eae493e3f0329f43e') as e from t order by e limit 1; +insert into t values (x'B047263C9F87233fcfd27e3eae493e3f0329f43e'); +drop table t; diff --git a/mysql-test/suite/sys_vars/r/sysvars_server_embedded,32bit.rdiff b/mysql-test/suite/sys_vars/r/sysvars_server_embedded,32bit.rdiff index 4ce2102da27..2069ee9b06f 100644 --- a/mysql-test/suite/sys_vars/r/sysvars_server_embedded,32bit.rdiff +++ b/mysql-test/suite/sys_vars/r/sysvars_server_embedded,32bit.rdiff @@ -602,7 +602,16 @@ VARIABLE_COMMENT Unused NUMERIC_MIN_VALUE 1 NUMERIC_MAX_VALUE 1024 -@@ -2174,7 +2174,7 @@ READ_ONLY YES +@@ -2167,7 +2167,7 @@ VARIABLE_SCOPE GLOBAL + VARIABLE_TYPE BIGINT UNSIGNED + VARIABLE_COMMENT Size of the cache for the MHNSW 
vector index + NUMERIC_MIN_VALUE 1048576 +-NUMERIC_MAX_VALUE 18446744073709551615 ++NUMERIC_MAX_VALUE 4294967295 + NUMERIC_BLOCK_SIZE 1 + ENUM_VALUE_LIST NULL + READ_ONLY NO +@@ -2204,7 +2204,7 @@ READ_ONLY NO COMMAND_LINE_ARGUMENT REQUIRED VARIABLE_NAME MIN_EXAMINED_ROW_LIMIT VARIABLE_SCOPE SESSION diff --git a/mysql-test/suite/sys_vars/r/sysvars_server_embedded.result b/mysql-test/suite/sys_vars/r/sysvars_server_embedded.result index 76924f9351e..a63c229b966 100644 --- a/mysql-test/suite/sys_vars/r/sysvars_server_embedded.result +++ b/mysql-test/suite/sys_vars/r/sysvars_server_embedded.result @@ -2182,6 +2182,16 @@ NUMERIC_BLOCK_SIZE 1 ENUM_VALUE_LIST NULL READ_ONLY YES COMMAND_LINE_ARGUMENT REQUIRED +VARIABLE_NAME MHNSW_CACHE_SIZE +VARIABLE_SCOPE GLOBAL +VARIABLE_TYPE BIGINT UNSIGNED +VARIABLE_COMMENT Size of the cache for the MHNSW vector index +NUMERIC_MIN_VALUE 1048576 +NUMERIC_MAX_VALUE 18446744073709551615 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT REQUIRED VARIABLE_NAME MHNSW_MAX_EDGES_PER_NODE VARIABLE_SCOPE SESSION VARIABLE_TYPE INT UNSIGNED diff --git a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded,32bit.rdiff b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded,32bit.rdiff index 2860aa2519d..16eb033849e 100644 --- a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded,32bit.rdiff +++ b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded,32bit.rdiff @@ -1,3 +1,5 @@ +diff --git a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result +index 0906f942121..1521ce1a728 100644 --- a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result +++ b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result @@ -44,7 +44,7 @@ READ_ONLY NO @@ -611,7 +613,16 @@ VARIABLE_COMMENT Unused NUMERIC_MIN_VALUE 1 NUMERIC_MAX_VALUE 1024 -@@ -2384,7 +2384,7 @@ READ_ONLY YES +@@ -2377,7 +2377,7 @@ VARIABLE_SCOPE GLOBAL + VARIABLE_TYPE BIGINT 
UNSIGNED + VARIABLE_COMMENT Size of the cache for the MHNSW vector index + NUMERIC_MIN_VALUE 1048576 +-NUMERIC_MAX_VALUE 18446744073709551615 ++NUMERIC_MAX_VALUE 4294967295 + NUMERIC_BLOCK_SIZE 1 + ENUM_VALUE_LIST NULL + READ_ONLY NO +@@ -2414,7 +2414,7 @@ READ_ONLY NO COMMAND_LINE_ARGUMENT REQUIRED VARIABLE_NAME MIN_EXAMINED_ROW_LIMIT VARIABLE_SCOPE SESSION diff --git a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result index 405fce5d7c1..db4636a4a52 100644 --- a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result +++ b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result @@ -2392,6 +2392,16 @@ NUMERIC_BLOCK_SIZE 1 ENUM_VALUE_LIST NULL READ_ONLY YES COMMAND_LINE_ARGUMENT REQUIRED +VARIABLE_NAME MHNSW_CACHE_SIZE +VARIABLE_SCOPE GLOBAL +VARIABLE_TYPE BIGINT UNSIGNED +VARIABLE_COMMENT Size of the cache for the MHNSW vector index +NUMERIC_MIN_VALUE 1048576 +NUMERIC_MAX_VALUE 18446744073709551615 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT REQUIRED VARIABLE_NAME MHNSW_MAX_EDGES_PER_NODE VARIABLE_SCOPE SESSION VARIABLE_TYPE INT UNSIGNED diff --git a/mysys/my_alloc.c b/mysys/my_alloc.c index e6e58a2795f..29c6692f60d 100644 --- a/mysys/my_alloc.c +++ b/mysys/my_alloc.c @@ -324,6 +324,7 @@ void *alloc_root(MEM_ROOT *mem_root, size_t length) size_t alloced_length; /* Increase block size over time if there is a lot of mallocs */ + /* when changing this logic, update root_size() to match */ block_size= (MY_ALIGN(mem_root->block_size, ROOT_MIN_BLOCK_SIZE) * (mem_root->block_num >> 2)- MALLOC_OVERHEAD); get_size= length + ALIGN_SIZE(sizeof(USED_MEM)); diff --git a/sql/bloom_filters.h b/sql/bloom_filters.h new file mode 100644 index 00000000000..934697f7d97 --- /dev/null +++ b/sql/bloom_filters.h @@ -0,0 +1,191 @@ +/* +MIT License + +Copyright (c) 2023 Sasha Krassovsky + +Permission is hereby granted, free of charge, to any person obtaining a copy 
+of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +// https://save-buffer.github.io/bloom_filter.html + +#pragma once +#include +#include +#include +#ifdef HAVE_IMMINTRIN_H +#include +#endif + +template +struct PatternedSimdBloomFilter +{ + PatternedSimdBloomFilter(int n, float eps) : n(n), epsilon(eps) + { + m = ComputeNumBits(); + int log_num_blocks = my_bit_log2_uint32(m) + 1 - rotate_bits; + num_blocks = (1ULL << log_num_blocks); + bv.resize(num_blocks); + } + + uint32_t ComputeNumBits() + { + double bits_per_val = -1.44 * std::log2(epsilon); + return std::max(512, static_cast(bits_per_val * n + 0.5)); + } + +#if __GNUC__ > 7 && defined(HAVE_IMMINTRIN_H) + __attribute__ ((target ("avx2,avx,fma"))) + __m256i CalcHash(__m256i vecData) + { + // (almost) xxHash parallel version, 64bit input, 64bit output, seed=0 + static constexpr __m256i rotl48={ + 0x0504030201000706ULL, 0x0D0C0B0A09080F0EULL, + 0x1514131211101716ULL, 0x1D1C1B1A19181F1EULL + }; + static constexpr __m256i rotl24={ + 0x0201000706050403ULL, 0x0A09080F0E0D0C0BULL, + 0x1211101716151413ULL, 
0x1A19181F1E1D1C1BULL, + }; + static constexpr uint64_t prime_mx2= 0x9FB21C651E98DF25ULL; + static constexpr uint64_t bitflip= 0xC73AB174C5ECD5A2ULL; + __m256i step1= _mm256_xor_si256(vecData, _mm256_set1_epi64x(bitflip)); + __m256i step2= _mm256_shuffle_epi8(step1, rotl48); + __m256i step3= _mm256_shuffle_epi8(step1, rotl24); + __m256i step4= _mm256_xor_si256(step1, _mm256_xor_si256(step2, step3)); + __m256i step5= _mm256_mul_epi32(step4, _mm256_set1_epi64x(prime_mx2)); + __m256i step6= _mm256_srli_epi64(step5, 35); + __m256i step7= _mm256_add_epi64(step6, _mm256_set1_epi64x(8)); + __m256i step8= _mm256_xor_si256(step5, step7); + __m256i step9= _mm256_mul_epi32(step8, _mm256_set1_epi64x(prime_mx2)); + return _mm256_xor_si256(step9, _mm256_srli_epi64(step9, 28)); + } + + __attribute__ ((target ("avx2,avx,fma"))) + __m256i GetBlockIdx(__m256i vecHash) + { + __m256i vecNumBlocksMask = _mm256_set1_epi64x(num_blocks - 1); + __m256i vecBlockIdx = _mm256_srli_epi64(vecHash, mask_idx_bits + rotate_bits); + return _mm256_and_si256(vecBlockIdx, vecNumBlocksMask); + } + + __attribute__ ((target ("avx2,avx,fma"))) + __m256i ConstructMask(__m256i vecHash) + { + __m256i vecMaskIdxMask = _mm256_set1_epi64x((1 << mask_idx_bits) - 1); + __m256i vecMaskMask = _mm256_set1_epi64x((1ull << bits_per_mask) - 1); + __m256i vec64 = _mm256_set1_epi64x(64); + + __m256i vecMaskIdx = _mm256_and_si256(vecHash, vecMaskIdxMask); + __m256i vecMaskByteIdx = _mm256_srli_epi64(vecMaskIdx, 3); + __m256i vecMaskBitIdx = _mm256_and_si256(vecMaskIdx, _mm256_set1_epi64x(0x7)); + __m256i vecRawMasks = _mm256_i64gather_epi64((const longlong *)masks, vecMaskByteIdx, 1); + __m256i vecUnrotated = _mm256_and_si256(_mm256_srlv_epi64(vecRawMasks, vecMaskBitIdx), vecMaskMask); + + __m256i vecRotation = _mm256_and_si256(_mm256_srli_epi64(vecHash, mask_idx_bits), _mm256_set1_epi64x((1 << rotate_bits) - 1)); + __m256i vecShiftUp = _mm256_sllv_epi64(vecUnrotated, vecRotation); + __m256i vecShiftDown = 
_mm256_srlv_epi64(vecUnrotated, _mm256_sub_epi64(vec64, vecRotation)); + return _mm256_or_si256(vecShiftDown, vecShiftUp); + } + + __attribute__ ((target ("avx2,avx,fma"))) + void Insert(const T **data) + { + __m256i vecDataA = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 0)); + __m256i vecDataB = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 4)); + + __m256i vecHashA= CalcHash(vecDataA); + __m256i vecHashB= CalcHash(vecDataB); + + __m256i vecMaskA = ConstructMask(vecHashA); + __m256i vecMaskB = ConstructMask(vecHashB); + + __m256i vecBlockIdxA = GetBlockIdx(vecHashA); + __m256i vecBlockIdxB = GetBlockIdx(vecHashB); + + uint64_t block0 = _mm256_extract_epi64(vecBlockIdxA, 0); + uint64_t block1 = _mm256_extract_epi64(vecBlockIdxA, 1); + uint64_t block2 = _mm256_extract_epi64(vecBlockIdxA, 2); + uint64_t block3 = _mm256_extract_epi64(vecBlockIdxA, 3); + uint64_t block4 = _mm256_extract_epi64(vecBlockIdxB, 0); + uint64_t block5 = _mm256_extract_epi64(vecBlockIdxB, 1); + uint64_t block6 = _mm256_extract_epi64(vecBlockIdxB, 2); + uint64_t block7 = _mm256_extract_epi64(vecBlockIdxB, 3); + + bv[block0] |= _mm256_extract_epi64(vecMaskA, 0); + bv[block1] |= _mm256_extract_epi64(vecMaskA, 1); + bv[block2] |= _mm256_extract_epi64(vecMaskA, 2); + bv[block3] |= _mm256_extract_epi64(vecMaskA, 3); + bv[block4] |= _mm256_extract_epi64(vecMaskB, 0); + bv[block5] |= _mm256_extract_epi64(vecMaskB, 1); + bv[block6] |= _mm256_extract_epi64(vecMaskB, 2); + bv[block7] |= _mm256_extract_epi64(vecMaskB, 3); + } + + __attribute__ ((target ("avx2,avx,fma"))) + uint8_t Query(T **data) + { + __m256i vecDataA = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 0)); + __m256i vecDataB = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 4)); + + __m256i vecHashA= CalcHash(vecDataA); + __m256i vecHashB= CalcHash(vecDataB); + + __m256i vecMaskA = ConstructMask(vecHashA); + __m256i vecMaskB = ConstructMask(vecHashB); + + __m256i vecBlockIdxA = GetBlockIdx(vecHashA); + 
__m256i vecBlockIdxB = GetBlockIdx(vecHashB); + + __m256i vecBloomA = _mm256_i64gather_epi64(bv.data(), vecBlockIdxA, sizeof(longlong)); + __m256i vecBloomB = _mm256_i64gather_epi64(bv.data(), vecBlockIdxB, sizeof(longlong)); + __m256i vecCmpA = _mm256_cmpeq_epi64(_mm256_and_si256(vecMaskA, vecBloomA), vecMaskA); + __m256i vecCmpB = _mm256_cmpeq_epi64(_mm256_and_si256(vecMaskB, vecBloomB), vecMaskB); + uint32_t res_a = static_cast(_mm256_movemask_epi8(vecCmpA)); + uint32_t res_b = static_cast(_mm256_movemask_epi8(vecCmpB)); + uint64_t res_bytes = res_a | (static_cast(res_b) << 32); + uint8_t res_bits = static_cast(_mm256_movemask_epi8(_mm256_set1_epi64x(res_bytes)) & 0xff); + return res_bits; + } +#endif + + int n; + float epsilon; + + uint64_t num_blocks; + uint32_t m; + // calculated from the upstream MaskTable and hard-coded + static constexpr int log_num_masks = 10; + static constexpr int bits_per_mask = 57; + const uint8_t masks[136]= {0x00, 0x04, 0x01, 0x04, 0x00, 0x20, 0x01, 0x00, + 0x00, 0x02, 0x08, 0x00, 0x02, 0x42, 0x00, 0x00, 0x04, 0x00, 0x00, 0x84, + 0x80, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x21, 0x00, 0x08, 0x00, 0x14, + 0x00, 0x00, 0x40, 0x00, 0x10, 0x00, 0xa8, 0x00, 0x00, 0x00, 0x00, 0x10, + 0x04, 0x40, 0x01, 0x00, 0x40, 0x00, 0x00, 0x08, 0x01, 0x02, 0x80, 0x00, + 0x00, 0x01, 0x00, 0x06, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x0c, 0x10, + 0x00, 0x10, 0x00, 0x00, 0x10, 0x08, 0x01, 0x10, 0x00, 0x00, 0x10, 0x20, + 0x00, 0x01, 0x20, 0x00, 0x02, 0x40, 0x00, 0x00, 0x02, 0x40, 0x01, 0x00, + 0x40, 0x00, 0x00, 0x0a, 0x00, 0x02, 0x01, 0x80, 0x00, 0x00, 0x10, 0x08, + 0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x50, 0x00, 0x08, 0x10, 0x20, 0x00, + 0x00, 0x80, 0x00, 0x10, 0x10, 0x04, 0x04, 0x00, 0x00, 0x00, 0x20, 0x20, + 0x08, 0x08, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00}; + std::vector bv; + + static constexpr int mask_idx_bits = log_num_masks; + static constexpr int rotate_bits = 6; +}; diff --git a/sql/handler.h b/sql/handler.h index 3f9adbfc0f5..812e6943c9f 100644 --- 
a/sql/handler.h +++ b/sql/handler.h @@ -559,6 +559,7 @@ enum legacy_db_type { /* note these numerical values are fixed and can *not* be changed */ DB_TYPE_UNKNOWN=0, + DB_TYPE_HLINDEX_HELPER=6, DB_TYPE_HEAP=6, DB_TYPE_MYISAM=9, DB_TYPE_MRG_MYISAM=10, diff --git a/sql/sql_base.cc b/sql/sql_base.cc index 48098d1d034..ec5cbb01588 100644 --- a/sql/sql_base.cc +++ b/sql/sql_base.cc @@ -2401,6 +2401,10 @@ retry_share: my_error(ER_NOT_SEQUENCE, MYF(0), table_list->db.str, table_list->alias.str); DBUG_RETURN(true); } + /* hlindexes don't support concurrent insert */ + if (table->s->hlindexes() && + table_list->lock_type == TL_WRITE_CONCURRENT_INSERT) + table_list->lock_type= TL_WRITE_DEFAULT; DBUG_ASSERT(thd->locked_tables_mode || table->file->row_logging == 0); DBUG_RETURN(false); diff --git a/sql/sys_vars.cc b/sql/sys_vars.cc index 71bb662a7bc..ff49e2ffeac 100644 --- a/sql/sys_vars.cc +++ b/sql/sys_vars.cc @@ -55,6 +55,7 @@ #include "opt_trace_context.h" #include "log_event.h" #include "optimizer_defaults.h" +#include "vector_mhnsw.h" #ifdef WITH_PERFSCHEMA_STORAGE_ENGINE #include "../storage/perfschema/pfs_server.h" @@ -7463,3 +7464,7 @@ static Sys_var_uint Sys_mhnsw_max_edges_per_node( "memory consumption, but better search results", SESSION_VAR(mhnsw_max_edges_per_node), CMD_LINE(REQUIRED_ARG), VALID_RANGE(3, 200), DEFAULT(6), BLOCK_SIZE(1)); +static Sys_var_ulonglong Sys_mhnsw_cache_size( + "mhnsw_cache_size", "Size of the cache for the MHNSW vector index", + GLOBAL_VAR(mhnsw_cache_size), CMD_LINE(REQUIRED_ARG), + VALID_RANGE(1024*1024, SIZE_T_MAX), DEFAULT(16*1024*1024), BLOCK_SIZE(1)); diff --git a/sql/table.cc b/sql/table.cc index b2fcd07dbff..15f68b0fe69 100644 --- a/sql/table.cc +++ b/sql/table.cc @@ -50,6 +50,7 @@ #include "sql_delete.h" // class Sql_cmd_delete #include "rpl_rli.h" // class rpl_group_info #include "rpl_mi.h" // class Master_info +#include "vector_mhnsw.h" #ifdef WITH_WSREP #include "wsrep_schema.h" @@ -505,7 +506,10 @@ void 
TABLE_SHARE::destroy() delete sequence; if (hlindex) + { + mhnsw_free(this); hlindex->destroy(); + } /* The mutexes are initialized only for shares that are part of the TDC */ if (tmp_table == NO_TMP_TABLE) @@ -4795,6 +4799,7 @@ int closefrm(TABLE *table) if (table->hlindex) closefrm(table->hlindex); + if (table->db_stat) error=table->file->ha_close(); table->alias.free(); diff --git a/sql/table.h b/sql/table.h index e0fec1d89c8..aab8d3cfc06 100644 --- a/sql/table.h +++ b/sql/table.h @@ -743,7 +743,11 @@ struct TABLE_SHARE Virtual_column_info **check_constraints; uint *blob_field; /* Index to blobs in Field arrray*/ LEX_CUSTRING vcol_defs; /* definitions of generated columns */ - TABLE_SHARE *hlindex; + + union { + void *hlindex_data; /* for hlindex tables */ + TABLE_SHARE *hlindex; /* for normal tables */ + }; /* EITS statistics data from the last time the table was opened or ANALYZE diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index 1addb915e6c..a62361f26dc 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -19,6 +19,10 @@ #include "vector_mhnsw.h" #include "item_vectorfunc.h" #include +#include +#include "bloom_filters.h" + +ulonglong mhnsw_cache_size; // Algorithm parameters static constexpr float alpha = 1.1f; @@ -27,9 +31,6 @@ static constexpr uint ef_construction= 10; // SIMD definitions #define SIMD_word (256/8) #define SIMD_floats (SIMD_word/sizeof(float)) -// how many extra bytes we need to alloc to be able to convert -// sizeof(double) aligned memory to SIMD_word aligned -#define SIMD_margin (SIMD_word - sizeof(double)) enum Graph_table_fields { FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS @@ -39,116 +40,492 @@ enum Graph_table_indices { }; class MHNSW_Context; +class FVectorNode; -class FVector: public Sql_alloc +/* + One vector, an array of ctx->vec_len floats + + Aligned on 32-byte (SIMD_word) boundary for SIMD, vector lenght + is zero-padded to multiples of 8, for the same reason. 
+*/ +class FVector { public: - MHNSW_Context *ctx; - FVector(MHNSW_Context *ctx_, const void *vec_); + FVector(MHNSW_Context *ctx_, MEM_ROOT *root, const void *vec_); float *vec; protected: - FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {} - void make_vec(const void *vec_); + FVector() : vec(nullptr) {} }; +/* + An array of pointers to graph nodes + + It's mainly used to store all neighbors of a given node on a given layer. + + An array is fixed size, 2*M for the zero layer, M for other layers + see MHNSW_Context::max_neighbors(). + + Number of neighbors is zero-padded to multiples of 8 (for SIMD Bloom filter). + + Also used as a simply array of nodes in search_layer, the array size + then is defined by ef or efConstruction. +*/ +struct Neighborhood: public Sql_alloc +{ + FVectorNode **links; + size_t num; + FVectorNode **init(FVectorNode **ptr, size_t n) + { + num= 0; + links= ptr; + n= MY_ALIGN(n, 8); + bzero(ptr, n*sizeof(*ptr)); + return ptr + n; + } +}; + + +/* + One node in a graph = one row in the graph table + + stores a vector itself, ref (= position) in the graph (= hlindex) + table, a ref in the main table, and an array of Neighborhood's, one + per layer. + + It's lazily initialized, may know only gref, everything else is + loaded on demand. + + On the other hand, on INSERT the new node knows everything except + gref - which only becomes known after ha_write_row. + + Allocated on memroot in two chunks. One is the same size for all nodes + and stores FVectorNode object, gref, tref, and vector. The second + stores neighbors, all Neighborhood's together, its size depends + on the number of layers this node is on. 
+ + There can be millions of nodes in the cache and the cache size + is constrained by mhnsw_cache_size, so every byte matters here +*/ +#pragma pack(push, 1) class FVectorNode: public FVector { private: - uchar *tref, *gref; - size_t max_layer; - - static uchar *gref_max; + MHNSW_Context *ctx; + float *make_vec(const void *v); int alloc_neighborhood(uint8_t layer); public: - List *neighbors= nullptr; + Neighborhood *neighbors= nullptr; + uint8_t max_layer; + bool stored; FVectorNode(MHNSW_Context *ctx_, const void *gref_); FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer, const void *vec_); float distance_to(const FVector &other) const; - int load(); - int load_from_record(); - int save(); - size_t get_tref_len() const; - uchar *get_tref() const { return tref; } - size_t get_gref_len() const; - uchar *get_gref() const { return gref; } + int load(TABLE *graph); + int load_from_record(TABLE *graph); + int save(TABLE *graph); + size_t tref_len() const; + size_t gref_len() const; + uchar *gref() const; + uchar *tref() const; + void push_neighbor(size_t layer, FVectorNode *v); static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool); }; +#pragma pack(pop) -// this assumes that 1) rows from graph table are never deleted, -// 2) and thus a ref for a new row is larger than refs of existing rows, -// thus we can treat the not-yet-inserted row as having max possible ref. -// oh, yes, and 3) 8 bytes ought to be enough for everyone -uchar *FVectorNode::gref_max=(uchar*)"\xff\xff\xff\xff\xff\xff\xff\xff"; +/* + Shared algorithm context. The graph. -class MHNSW_Context + Stored in TABLE_SHARE and on TABLE_SHARE::mem_root. + Stores the complete graph in MHNSW_Context::root, + The mapping gref->FVectorNode is in the node_cache. + Both root and node_cache are protected by a cache_lock, but it's + needed when loading nodes and is not used when the whole graph is in memory. 
+ Graph can be traversed concurrently by different threads, as traversal + changes neither nodes nor the ctx. + Nodes can be loaded concurrently by different threads, this is protected + by a partitioned node_lock. + reference counter allows flushing the graph without interrupting + concurrent searches. + MyISAM automatically gets exclusive write access because of the TL_WRITE, + but InnoDB has to use a dedicated ctx->commit_lock for that +*/ +class MHNSW_Context : public Sql_alloc { - public: - MEM_ROOT root; - TABLE *table; - Field *vec_field; - size_t vec_len= 0; - size_t byte_len= 0; - uint err= 0; + std::atomic refcnt{0}; + mysql_mutex_t cache_lock; + mysql_mutex_t node_lock[8]; + void cache_internal(FVectorNode *node) + { + DBUG_ASSERT(node->stored); + node_cache.insert(node); + } + void *alloc_node_internal() + { + return alloc_root(&root, sizeof(FVectorNode) + gref_len + tref_len + + vec_len * sizeof(float) + SIMD_word - 1); + } + +protected: + MEM_ROOT root; Hash_set node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key}; - MHNSW_Context(TABLE *table, Field *vec_field) - : table(table), vec_field(vec_field) +public: + mysql_rwlock_t commit_lock; + size_t vec_len= 0; + size_t byte_len= 0; + Atomic_relaxed ef_power{0.6}; // for the bloom filter size heuristic + FVectorNode *start= 0; + const uint tref_len; + const uint gref_len; + const uint M; + + MHNSW_Context(TABLE *t) + : tref_len(t->file->ref_length), + gref_len(t->hlindex->file->ref_length), + M(t->in_use->variables.mhnsw_max_edges_per_node) { - init_alloc_root(PSI_INSTRUMENT_MEM, &root, 8192, 0, MYF(MY_THREAD_SPECIFIC)); + mysql_rwlock_init(PSI_INSTRUMENT_ME, &commit_lock); + mysql_mutex_init(PSI_INSTRUMENT_ME, &cache_lock, MY_MUTEX_INIT_FAST); + for (uint i=0; i < array_elements(node_lock); i++) + mysql_mutex_init(PSI_INSTRUMENT_ME, node_lock + i, MY_MUTEX_INIT_SLOW); + init_alloc_root(PSI_INSTRUMENT_MEM, &root, 1024*1024, 0, MYF(0)); } - ~MHNSW_Context() + virtual ~MHNSW_Context() { free_root(&root, 
MYF(0)); + mysql_rwlock_destroy(&commit_lock); + mysql_mutex_destroy(&cache_lock); + for (size_t i=0; i < array_elements(node_lock); i++) + mysql_mutex_destroy(node_lock + i); + } + + uint lock_node(FVectorNode *ptr) + { + ulong nr1= 1, nr2= 4; + my_hash_sort_bin(0, (uchar*)&ptr, sizeof(ptr), &nr1, &nr2); + uint ticket= nr1 % array_elements(node_lock); + mysql_mutex_lock(node_lock + ticket); + return ticket; + } + + void unlock_node(uint ticket) + { + mysql_mutex_unlock(node_lock + ticket); + } + + uint max_neighbors(size_t layer) const + { + return (layer ? 1 : 2) * M; // heuristic from the paper } - FVectorNode *get_node(const void *gref); void set_lengths(size_t len) { byte_len= len; vec_len= MY_ALIGN(byte_len/sizeof(float), SIMD_floats); } + + static int acquire(MHNSW_Context **ctx, TABLE *table, bool for_update); + static MHNSW_Context *get_from_share(TABLE_SHARE *share, TABLE *table); + + void reset_ctx(TABLE_SHARE *share) + { + mysql_mutex_lock(&share->LOCK_share); + if (static_cast(share->hlindex->hlindex_data) == this) + { + share->hlindex->hlindex_data= nullptr; + --refcnt; + } + mysql_mutex_unlock(&share->LOCK_share); + } + + void release(TABLE *table) + { + return release(table->file->has_transactions(), table->s); + } + + virtual void release(bool can_commit, TABLE_SHARE *share) + { + if (can_commit) + mysql_rwlock_unlock(&commit_lock); + if (root_size(&root) > mhnsw_cache_size) + reset_ctx(share); + if (--refcnt == 0) + this->~MHNSW_Context(); // XXX reuse + } + + FVectorNode *get_node(const void *gref) + { + mysql_mutex_lock(&cache_lock); + FVectorNode *node= node_cache.find(gref, gref_len); + if (!node) + { + node= new (alloc_node_internal()) FVectorNode(this, gref); + cache_internal(node); + } + mysql_mutex_unlock(&cache_lock); + return node; + } + + /* used on INSERT, gref isn't known, so cannot cache the node yet */ + void *alloc_node() + { + mysql_mutex_lock(&cache_lock); + auto p= alloc_node_internal(); + mysql_mutex_unlock(&cache_lock); + 
return p; + } + + /* explicitly cache the node after alloc_node() */ + void cache_node(FVectorNode *node) + { + mysql_mutex_lock(&cache_lock); + cache_internal(node); + mysql_mutex_unlock(&cache_lock); + } + + /* find the node without creating, only used on merging trx->ctx */ + FVectorNode *find_node(const void *gref) + { + mysql_mutex_lock(&cache_lock); + FVectorNode *node= node_cache.find(gref, gref_len); + mysql_mutex_unlock(&cache_lock); + return node; + } + + void *alloc_neighborhood(size_t max_layer) + { + mysql_mutex_lock(&cache_lock); + auto p= alloc_root(&root, sizeof(Neighborhood)*(max_layer+1) + + sizeof(FVectorNode*)*(MY_ALIGN(M, 4)*2 + MY_ALIGN(M,8)*max_layer)); + mysql_mutex_unlock(&cache_lock); + return p; + } }; -FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_) +/* + This is a non-shared context that exists within one transaction. + + At the end of the transaction it's either discarded (on rollback) + or merged into the shared ctx (on commit). + + trx's are stored in thd->ha_data[] in a single-linked list, + one instance of trx per TABLE_SHARE and allocated on the + thd->transaction->mem_root +*/ +class MHNSW_Trx : public MHNSW_Context { - make_vec(vec_); +public: + TABLE_SHARE *table_share; + bool list_of_nodes_is_lost= false; + MHNSW_Trx *next= nullptr; + + MHNSW_Trx(TABLE *table) : MHNSW_Context(table), table_share(table->s) {} + void reset_trx() + { + node_cache.clear(); + free_root(&root, MYF(0)); + start= 0; + list_of_nodes_is_lost= true; + } + void release(bool, TABLE_SHARE *) override + { + if (root_size(&root) > mhnsw_cache_size) + reset_trx(); + } + + static MHNSW_Trx *get_from_thd(THD *thd, TABLE *table); + + // it's okay in a transaction-local cache, there's no concurrent access + Hash_set &get_cache() { return node_cache; } + + /* fake handlerton to use thd->ha_data and to get notified of commits */ + static struct MHNSW_hton : public handlerton + { + MHNSW_hton() + { + db_type= DB_TYPE_HLINDEX_HELPER; + flags = 
HTON_NOT_USER_SELECTABLE | HTON_HIDDEN; + savepoint_offset= 0; + savepoint_set= [](handlerton *, THD *, void *){ return 0; }; + savepoint_rollback_can_release_mdl= [](handlerton *, THD *){ return true; }; + savepoint_rollback= do_savepoint_rollback; + commit= do_commit; + rollback= do_rollback; + } + static int do_commit(handlerton *, THD *thd, bool); + static int do_rollback(handlerton *, THD *thd, bool); + static int do_savepoint_rollback(handlerton *, THD *thd, void *); + } hton; +}; + +MHNSW_Trx::MHNSW_hton MHNSW_Trx::hton; + +int MHNSW_Trx::MHNSW_hton::do_savepoint_rollback(handlerton *, THD *thd, void *) +{ + for (auto trx= static_cast(thd_get_ha_data(thd, &hton)); + trx; trx= trx->next) + trx->reset_trx(); + return 0; } -void FVector::make_vec(const void *vec_) +int MHNSW_Trx::MHNSW_hton::do_rollback(handlerton *, THD *thd, bool) { - DBUG_ASSERT(ctx->vec_len); - vec= (float*)alloc_root(&ctx->root, - ctx->vec_len * sizeof(float) + SIMD_margin); - if (int off= ((intptr)vec) % SIMD_word) - vec += (SIMD_word - off) / sizeof(float); - memcpy(vec, vec_, ctx->byte_len); - for (size_t i=ctx->byte_len/sizeof(float); i < ctx->vec_len; i++) - vec[i]=0; + MHNSW_Trx *trx_next; + for (auto trx= static_cast(thd_get_ha_data(thd, &hton)); + trx; trx= trx_next) + { + trx_next= trx->next; + trx->~MHNSW_Trx(); + } + thd_set_ha_data(current_thd, &hton, nullptr); + return 0; +} + +int MHNSW_Trx::MHNSW_hton::do_commit(handlerton *, THD *thd, bool) +{ + MHNSW_Trx *trx_next; + for (auto trx= static_cast(thd_get_ha_data(thd, &hton)); + trx; trx= trx_next) + { + trx_next= trx->next; + auto ctx= MHNSW_Context::get_from_share(trx->table_share, nullptr); + if (ctx) + { + mysql_rwlock_wrlock(&ctx->commit_lock); + if (trx->list_of_nodes_is_lost) + ctx->reset_ctx(trx->table_share); + else + { + // consider copying nodes from trx to shared cache when it makes sense + // for ann_benchmarks it does not + // also, consider flushing only changed nodes (a flag in the node) + for (FVectorNode 
&from : trx->get_cache()) + if (FVectorNode *node= ctx->find_node(from.gref())) + node->vec= nullptr; + ctx->start= nullptr; + } + ctx->release(true, trx->table_share); + } + trx->~MHNSW_Trx(); + } + thd_set_ha_data(current_thd, &hton, nullptr); + return 0; +} + +MHNSW_Trx *MHNSW_Trx::get_from_thd(THD *thd, TABLE *table) +{ + auto trx= static_cast(thd_get_ha_data(thd, &hton)); + while (trx && trx->table_share != table->s) trx= trx->next; + if (!trx) + { + trx= new (&thd->transaction->mem_root) MHNSW_Trx(table); + trx->next= static_cast(thd_get_ha_data(thd, &hton)); + thd_set_ha_data(thd, &hton, trx); + if (!trx->next) + { + bool all= thd_test_options(thd, OPTION_NOT_AUTOCOMMIT | OPTION_BEGIN); + trans_register_ha(thd, all, &hton, 0); + } + } + return trx; +} + +MHNSW_Context *MHNSW_Context::get_from_share(TABLE_SHARE *share, TABLE *table) +{ + mysql_mutex_lock(&share->LOCK_share); + auto ctx= static_cast(share->hlindex->hlindex_data); + if (!ctx && table) + { + ctx= new (&share->hlindex->mem_root) MHNSW_Context(table); + if (!ctx) return nullptr; + share->hlindex->hlindex_data= ctx; + ctx->refcnt++; + } + if (ctx) + ctx->refcnt++; + mysql_mutex_unlock(&share->LOCK_share); + return ctx; +} + +int MHNSW_Context::acquire(MHNSW_Context **ctx, TABLE *table, bool for_update) +{ + TABLE *graph= table->hlindex; + THD *thd= table->in_use; + + if (table->file->has_transactions() && + (for_update || thd_get_ha_data(thd, &MHNSW_Trx::hton))) + *ctx= MHNSW_Trx::get_from_thd(thd, table); + else + { + *ctx= MHNSW_Context::get_from_share(table->s, table); + if (table->file->has_transactions()) + mysql_rwlock_rdlock(&(*ctx)->commit_lock); + } + + if ((*ctx)->start) + return 0; + + if (int err= graph->file->ha_index_init(IDX_LAYER, 1)) + return err; + + int err= graph->file->ha_index_last(graph->record[0]); + graph->file->ha_index_end(); + if (err) + return err; + + graph->file->position(graph->record[0]); + (*ctx)->set_lengths(graph->field[FIELD_VEC]->value_length()); + 
(*ctx)->start= (*ctx)->get_node(graph->file->ref); + return (*ctx)->start->load_from_record(graph); +} + +/* copy the vector, aligned and padded for SIMD */ +static float *make_vec(void *mem, const void *src, size_t src_len) +{ + auto dst= (float*)MY_ALIGN((intptr)mem, SIMD_word); + memcpy(dst, src, src_len); + const size_t start= src_len/sizeof(float); + for (size_t i= start; i < MY_ALIGN(start, SIMD_floats); i++) + dst[i]=0.0f; + return dst; +} + +FVector::FVector(MHNSW_Context *ctx, MEM_ROOT *root, const void *vec_) +{ + vec= make_vec(alloc_root(root, ctx->vec_len * sizeof(float) + SIMD_word - 1), + vec_, ctx->byte_len); +} + +float *FVectorNode::make_vec(const void *v) +{ + return ::make_vec(tref() + tref_len(), v, ctx->byte_len); } FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_) - : FVector(ctx_), tref(nullptr) + : FVector(), ctx(ctx_), stored(true) { - gref= (uchar*)memdup_root(&ctx->root, gref_, get_gref_len()); + memcpy(gref(), gref_, gref_len()); } FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer, const void *vec_) - : FVector(ctx_, vec_), gref(gref_max) + : FVector(), ctx(ctx_), stored(false) { - tref= (uchar*)memdup_root(&ctx->root, tref_, get_tref_len()); + memset(gref(), 0xff, gref_len()); // important: larger than any real gref + memcpy(tref(), tref_, tref_len()); + vec= make_vec(vec_); + alloc_neighborhood(layer); } float FVectorNode::distance_to(const FVector &other) const { - const_cast(this)->load(); #if __GNUC__ > 7 typedef float v8f __attribute__((vector_size(SIMD_word))); v8f *p1= (v8f*)vec; @@ -167,295 +544,353 @@ float FVectorNode::distance_to(const FVector &other) const int FVectorNode::alloc_neighborhood(uint8_t layer) { - DBUG_ASSERT(!neighbors); + if (neighbors) + return 0; max_layer= layer; - neighbors= new (&ctx->root) List[layer+1]; + neighbors= (Neighborhood*)ctx->alloc_neighborhood(layer); + auto ptr= (FVectorNode**)(neighbors + (layer+1)); + for (size_t i= 0; i <= layer; i++) + ptr= 
neighbors[i].init(ptr, ctx->max_neighbors(i)); return 0; } -int FVectorNode::load() +int FVectorNode::load(TABLE *graph) { - DBUG_ASSERT(gref); - if (tref) + if (likely(vec)) return 0; - TABLE *graph= ctx->table->hlindex; - if ((ctx->err= graph->file->ha_rnd_pos(graph->record[0], gref))) - return ctx->err; - return load_from_record(); + DBUG_ASSERT(stored); + // trx: consider loading nodes from shared, when it makes sense + // for ann_benchmarks it does not + if (int err= graph->file->ha_rnd_pos(graph->record[0], gref())) + return err; + return load_from_record(graph); } -int FVectorNode::load_from_record() +int FVectorNode::load_from_record(TABLE *graph) { - TABLE *graph= ctx->table->hlindex; + DBUG_ASSERT(ctx->byte_len); + + uint ticket= ctx->lock_node(this); + SCOPE_EXIT([this, ticket](){ ctx->unlock_node(ticket); }); + + if (vec) + return 0; + String buf, *v= graph->field[FIELD_TREF]->val_str(&buf); - if (unlikely(!v || v->length() != get_tref_len())) - return ctx->err= HA_ERR_CRASHED; - tref= (uchar*)memdup_root(&ctx->root, v->ptr(), v->length()); + if (unlikely(!v || v->length() != tref_len())) + return my_errno= HA_ERR_CRASHED; + memcpy(tref(), v->ptr(), v->length()); v= graph->field[FIELD_VEC]->val_str(&buf); if (unlikely(!v)) - return ctx->err= HA_ERR_CRASHED; + return my_errno= HA_ERR_CRASHED; - DBUG_ASSERT(ctx->byte_len); if (v->length() != ctx->byte_len) - return ctx->err= HA_ERR_CRASHED; - make_vec(v->ptr()); + return my_errno= HA_ERR_CRASHED; + float *vec_ptr= make_vec(v->ptr()); longlong layer= graph->field[FIELD_LAYER]->val_int(); if (layer > 100) // 10e30 nodes at M=2, more at larger M's - return ctx->err= HA_ERR_CRASHED; + return my_errno= HA_ERR_CRASHED; - if (alloc_neighborhood(static_cast(layer))) - return ctx->err; + if (int err= alloc_neighborhood(static_cast(layer))) + return err; v= graph->field[FIELD_NEIGHBORS]->val_str(&buf); if (unlikely(!v)) - return ctx->err= HA_ERR_CRASHED; + return my_errno= HA_ERR_CRASHED; // ... ...etc... 
uchar *ptr= (uchar*)v->ptr(), *end= ptr + v->length(); for (size_t i=0; i <= max_layer; i++) { if (unlikely(ptr >= end)) - return ctx->err= HA_ERR_CRASHED; + return my_errno= HA_ERR_CRASHED; size_t grefs= *ptr++; - if (unlikely(ptr + grefs * get_gref_len() > end)) - return ctx->err= HA_ERR_CRASHED; - for (; grefs--; ptr+= get_gref_len()) - neighbors[i].push_back(ctx->get_node(ptr), &ctx->root); + if (unlikely(ptr + grefs * gref_len() > end)) + return my_errno= HA_ERR_CRASHED; + neighbors[i].num= grefs; + for (size_t j=0; j < grefs; j++, ptr+= gref_len()) + neighbors[i].links[j]= ctx->get_node(ptr); } + vec= vec_ptr; // must be done at the very end return 0; } -size_t FVectorNode::get_tref_len() const +void FVectorNode::push_neighbor(size_t layer, FVectorNode *other) { - return ctx->table->file->ref_length; + DBUG_ASSERT(neighbors[layer].num < ctx->max_neighbors(layer)); + neighbors[layer].links[neighbors[layer].num++]= other; } -size_t FVectorNode::get_gref_len() const -{ - return ctx->table->hlindex->file->ref_length; -} +size_t FVectorNode::tref_len() const { return ctx->tref_len; } +size_t FVectorNode::gref_len() const { return ctx->gref_len; } +uchar *FVectorNode::gref() const { return (uchar*)(this+1); } +uchar *FVectorNode::tref() const { return gref() + gref_len(); } uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool) { - *key_len= elem->get_gref_len(); - return elem->gref; + *key_len= elem->gref_len(); + return elem->gref(); } -FVectorNode *MHNSW_Context::get_node(const void *gref) +/* one visited node during the search. 
caches the distance to target */ +struct Visited : public Sql_alloc { - FVectorNode *node= node_cache.find(gref, table->hlindex->file->ref_length); - if (!node) + FVectorNode *node; + const float distance_to_target; + Visited(FVectorNode *n, float d) : node(n), distance_to_target(d) {} + static int cmp(void *, const Visited* a, const Visited *b) { - node= new (&root) FVectorNode(this, gref); - node_cache.insert(node); + return a->distance_to_target < b->distance_to_target ? -1 : + a->distance_to_target > b->distance_to_target ? 1 : 0; } - return node; -} +}; -static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNode *b) +/* + a factory to create Visited and keep track of already seen nodes + + note that PatternedSimdBloomFilter works in blocks of 8 elements, + so on insert they're accumulated in nodes[], on search the caller + provides 8 addresses at once. we record 0x0 as "seen" so that + the caller could pad the input with nullptr's +*/ +class VisitedSet { - float a_dist= a->distance_to(*target); - float b_dist= b->distance_to(*target); + MEM_ROOT *root; + const FVector ⌖ + PatternedSimdBloomFilter map; + const FVectorNode *nodes[8]= {0,0,0,0,0,0,0,0}; + size_t idx= 1; // to record 0 in the filter + public: + uint count= 0; + VisitedSet(MEM_ROOT *root, const FVector &target, uint size) : + root(root), target(target), map(size, 0.01f) {} + Visited *create(FVectorNode *node) + { + auto *v= new (root) Visited(node, node->distance_to(target)); + insert(node); + count++; + return v; + } + void insert(const FVectorNode *n) + { + nodes[idx++]= n; + if (idx == 8) flush(); + } + void flush() { + if (idx) map.Insert(nodes); + idx=0; + } + uint8_t seen(FVectorNode **nodes) { return map.Query(nodes); } +}; - if (a_dist < b_dist) - return -1; - if (a_dist > b_dist) - return 1; - return 0; -} -static int select_neighbors(MHNSW_Context *ctx, size_t layer, - FVectorNode &target, - const List &candidates_unsafe, +/* + selects best neighbors from the list 
of candidates plus one extra candidate + + one extra candidate is specified separately to avoid appending it to + the Neighborhood candidates, which might be already at its max size. +*/ +static int select_neighbors(MHNSW_Context *ctx, TABLE *graph, size_t layer, + FVectorNode &target, const Neighborhood &candidates, + FVectorNode *extra_candidate, size_t max_neighbor_connections) { - Hash_set visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key); - Queue pq; // working queue - Queue pq_discard; // queue for discarded candidates - /* - make a copy of candidates in case it's target.neighbors[layer]. - because we're going to modify the latter below - */ - List candidates= candidates_unsafe; - List &neighbors= target.neighbors[layer]; + Queue pq; // working queue - neighbors.empty(); + if (pq.init(10000, false, Visited::cmp)) + return my_errno= HA_ERR_OUT_OF_MEM; - if (pq.init(10000, 0, cmp_vec, &target) || - pq_discard.init(10000, 0, cmp_vec, &target)) - return ctx->err= HA_ERR_OUT_OF_MEM; + MEM_ROOT * const root= graph->in_use->mem_root; + auto discarded= (Visited**)my_safe_alloca(sizeof(Visited**)*max_neighbor_connections); + size_t discarded_num= 0; + Neighborhood &neighbors= target.neighbors[layer]; - for (const FVectorNode &candidate : candidates) + for (size_t i=0; i < candidates.num; i++) { - visited.insert(&candidate); - pq.push(&candidate); + FVectorNode *node= candidates.links[i]; + if (int err= node->load(graph)) + return err; + pq.push(new (root) Visited(node, node->distance_to(target))); } + if (extra_candidate) + pq.push(new (root) Visited(extra_candidate, extra_candidate->distance_to(target))); DBUG_ASSERT(pq.elements()); - neighbors.push_back(pq.pop(), &ctx->root); + neighbors.num= 0; - while (pq.elements() && neighbors.elements < max_neighbor_connections) + while (pq.elements() && neighbors.num < max_neighbor_connections) { - const FVectorNode *vec= pq.pop(); - const float target_dist= vec->distance_to(target); - const float target_dista= target_dist / 
alpha; + Visited *vec= pq.pop(); + FVectorNode * const node= vec->node; + const float target_dista= vec->distance_to_target / alpha; bool discard= false; - for (const FVectorNode &neigh : neighbors) - { - if ((discard= vec->distance_to(neigh) < target_dista)) + for (size_t i=0; i < neighbors.num; i++) + if ((discard= node->distance_to(*neighbors.links[i]) < target_dista)) break; - } if (!discard) - neighbors.push_back(vec, &ctx->root); - else if (pq_discard.elements() + neighbors.elements < max_neighbor_connections) - pq_discard.push(vec); + target.push_neighbor(layer, node); + else if (discarded_num + neighbors.num < max_neighbor_connections) + discarded[discarded_num++]= vec; } - while (pq_discard.elements() && neighbors.elements < max_neighbor_connections) - { - const FVectorNode *vec= pq_discard.pop(); - neighbors.push_back(vec, &ctx->root); - } + for (size_t i=0; i < discarded_num && neighbors.num < max_neighbor_connections; i++) + target.push_neighbor(layer, discarded[i]->node); + my_safe_afree(discarded, sizeof(Visited**)*max_neighbor_connections); return 0; } -int FVectorNode::save() +int FVectorNode::save(TABLE *graph) { - TABLE *graph= ctx->table->hlindex; - - DBUG_ASSERT(tref); DBUG_ASSERT(vec); DBUG_ASSERT(neighbors); restore_record(graph, s->default_values); graph->field[FIELD_LAYER]->store(max_layer, false); graph->field[FIELD_TREF]->set_notnull(); - graph->field[FIELD_TREF]->store_binary(tref, get_tref_len()); + graph->field[FIELD_TREF]->store_binary(tref(), tref_len()); graph->field[FIELD_VEC]->store_binary((uchar*)vec, ctx->byte_len); size_t total_size= 0; for (size_t i=0; i <= max_layer; i++) - total_size+= 1 + get_gref_len() * neighbors[i].elements; + total_size+= 1 + gref_len() * neighbors[i].num; uchar *neighbor_blob= static_cast(my_safe_alloca(total_size)); uchar *ptr= neighbor_blob; for (size_t i= 0; i <= max_layer; i++) { - *ptr++= (uchar)(neighbors[i].elements); - for (const auto &neigh: neighbors[i]) - { - memcpy(ptr, neigh.get_gref(), 
get_gref_len()); - ptr+= neigh.get_gref_len(); - } + *ptr++= (uchar)(neighbors[i].num); + for (size_t j= 0; j < neighbors[i].num; j++, ptr+= gref_len()) + memcpy(ptr, neighbors[i].links[j]->gref(), gref_len()); } graph->field[FIELD_NEIGHBORS]->store_binary(neighbor_blob, total_size); - if (gref != gref_max) + int err; + if (stored) { - ctx->err= graph->file->ha_rnd_pos(graph->record[1], gref); - if (!ctx->err) + if (!(err= graph->file->ha_rnd_pos(graph->record[1], gref()))) { - ctx->err= graph->file->ha_update_row(graph->record[1], graph->record[0]); - if (ctx->err == HA_ERR_RECORD_IS_THE_SAME) - ctx->err= 0; + err= graph->file->ha_update_row(graph->record[1], graph->record[0]); + if (err == HA_ERR_RECORD_IS_THE_SAME) + err= 0; } } else { - ctx->err= graph->file->ha_write_row(graph->record[0]); + err= graph->file->ha_write_row(graph->record[0]); graph->file->position(graph->record[0]); - gref= (uchar*)memdup_root(&ctx->root, graph->file->ref, get_gref_len()); + memcpy(gref(), graph->file->ref, gref_len()); + stored= true; + ctx->cache_node(this); } - my_safe_afree(neighbor_blob, total_size); - return ctx->err; + return err; } - -static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer, - uint max_neighbors, - const FVectorNode &node) +static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph, + size_t layer, FVectorNode *node) { - for (FVectorNode &neigh: node.neighbors[layer]) + const uint max_neighbors= ctx->max_neighbors(layer); + // it seems that one could update nodes in the gref order + // to avoid InnoDB deadlocks, but it produces no noticeable effect + for (size_t i=0; i < node->neighbors[layer].num; i++) { - List &neighneighbors= neigh.neighbors[layer]; - neighneighbors.push_back(&node, &ctx->root); - if (neighneighbors.elements > max_neighbors) - { - if (select_neighbors(ctx, layer, neigh, neighneighbors, max_neighbors)) - return ctx->err; - } - if (neigh.save()) - return ctx->err; + FVectorNode *neigh= 
node->neighbors[layer].links[i]; + Neighborhood &neighneighbors= neigh->neighbors[layer]; + if (neighneighbors.num < max_neighbors) + neigh->push_neighbor(layer, node); + else + if (int err= select_neighbors(ctx, graph, layer, *neigh, neighneighbors, + node, max_neighbors)) + return err; + if (int err= neigh->save(graph)) + return err; } return 0; } - -static int search_layer(MHNSW_Context *ctx, const FVector &target, - const List &start_nodes, - uint max_candidates_return, size_t layer, - List *result) +static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector &target, + Neighborhood *start_nodes, uint ef, size_t layer, + Neighborhood *result) { - DBUG_ASSERT(start_nodes.elements > 0); - DBUG_ASSERT(result->elements == 0); + DBUG_ASSERT(start_nodes->num > 0); + result->num= 0; - Queue candidates; - Queue best; - Hash_set visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key); + MEM_ROOT * const root= graph->in_use->mem_root; - candidates.init(10000, false, cmp_vec, &target); - best.init(max_candidates_return, true, cmp_vec, &target); + Queue candidates; + Queue best; - for (const FVectorNode &node : start_nodes) + // WARNING! 
heuristic here + const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer)); + const uint est_size= static_cast(est_heuristic * std::pow(ef, ctx->ef_power)); + VisitedSet visited(root, target, est_size); + + candidates.init(10000, false, Visited::cmp); + best.init(ef, true, Visited::cmp); + + for (size_t i=0; i < start_nodes->num; i++) { - candidates.push(&node); - if (best.elements() < max_candidates_return) - best.push(&node); - else if (node.distance_to(target) > best.top()->distance_to(target)) - best.replace_top(&node); - visited.insert(&node); + Visited *v= visited.create(start_nodes->links[i]); + candidates.push(v); + if (best.elements() < ef) + best.push(v); + else if (v->distance_to_target < best.top()->distance_to_target) + best.replace_top(v); } - float furthest_best= best.top()->distance_to(target); + float furthest_best= best.top()->distance_to_target; while (candidates.elements()) { - const FVectorNode &cur_vec= *candidates.pop(); - float cur_distance= cur_vec.distance_to(target); - if (cur_distance > furthest_best && best.elements() == max_candidates_return) - { - break; // All possible candidates are worse than what we have. - // Can't get better. 
- } + const Visited &cur= *candidates.pop(); + if (cur.distance_to_target > furthest_best && best.elements() == ef) + break; // All possible candidates are worse than what we have - for (const FVectorNode &neigh: cur_vec.neighbors[layer]) + visited.flush(); + + Neighborhood &neighbors= cur.node->neighbors[layer]; + FVectorNode **links= neighbors.links, **end= links + neighbors.num; + for (; links < end; links+= 8) { - if (visited.find(&neigh)) + uint8_t res= visited.seen(links); + if (res == 0xff) continue; - visited.insert(&neigh); - if (best.elements() < max_candidates_return) + for (size_t i= 0; i < 8; i++) { - candidates.push(&neigh); - best.push(&neigh); - furthest_best= best.top()->distance_to(target); - } - else if (neigh.distance_to(target) < furthest_best) - { - best.replace_top(&neigh); - candidates.push(&neigh); - furthest_best= best.top()->distance_to(target); + if (res & (1 << i)) + continue; + if (int err= links[i]->load(graph)) + return err; + Visited *v= visited.create(links[i]); + if (best.elements() < ef) + { + candidates.push(v); + best.push(v); + furthest_best= best.top()->distance_to_target; + } + else if (v->distance_to_target < furthest_best) + { + best.replace_top(v); + candidates.push(v); + furthest_best= best.top()->distance_to_target; + } } } } + if (ef > 1 && visited.count*2 > est_size) + { + double ef_power= std::log(visited.count*2/est_heuristic) / std::log(ef); + set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok + } - while (best.elements()) - result->push_front(best.pop(), &ctx->root); + result->num= best.elements(); + for (FVectorNode **links= result->links + result->num; best.elements();) + *--links= best.pop()->node; return 0; } @@ -466,7 +901,7 @@ static int bad_value_on_insert(Field *f) my_error(ER_TRUNCATED_WRONG_VALUE_FOR_FIELD, MYF(0), "vector", "...", f->table->s->db.str, f->table->s->table_name.str, f->field_name.str, f->table->in_use->get_stmt_da()->current_row_for_warning()); - return HA_ERR_GENERIC; + 
return my_errno= HA_ERR_GENERIC; } @@ -477,7 +912,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); Field *vec_field= keyinfo->key_part->field; String buf, *res= vec_field->val_str(&buf); - MHNSW_Context ctx(table, vec_field); + MHNSW_Context *ctx; /* metadata are checked on open */ DBUG_ASSERT(graph); @@ -495,92 +930,80 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) if (res->length() == 0 || res->length() % 4) return bad_value_on_insert(vec_field); - const double NORMALIZATION_FACTOR= 1 / std::log(thd->variables.mhnsw_max_edges_per_node); - table->file->position(table->record[0]); - if (int err= graph->file->ha_index_init(IDX_LAYER, 1)) - return err; - - ctx.err= graph->file->ha_index_last(graph->record[0]); - graph->file->ha_index_end(); - - if (ctx.err) + int err= MHNSW_Context::acquire(&ctx, table, true); + SCOPE_EXIT([ctx, table](){ ctx->release(table); }); + if (err) { - if (ctx.err != HA_ERR_END_OF_FILE) - return ctx.err; - ctx.err= 0; + if (err != HA_ERR_END_OF_FILE) + return err; // First insert! 
- ctx.set_lengths(res->length()); - FVectorNode target(&ctx, table->file->ref, 0, res->ptr()); - return target.save(); + ctx->set_lengths(res->length()); + FVectorNode *target= new (ctx->alloc_node()) + FVectorNode(ctx, table->file->ref, 0, res->ptr()); + if (!((err= target->save(graph)))) + ctx->start= target; + return err; } - longlong max_layer= graph->field[FIELD_LAYER]->val_int(); - - List candidates; - List start_nodes; - - graph->file->position(graph->record[0]); - FVectorNode *start_node= ctx.get_node(graph->file->ref); - - if (start_nodes.push_back(start_node, &ctx.root)) - return HA_ERR_OUT_OF_MEM; - - ctx.set_lengths(graph->field[FIELD_VEC]->value_length()); - if (int err= start_node->load_from_record()) - return err; - - if (ctx.byte_len != res->length()) + if (ctx->byte_len != res->length()) return bad_value_on_insert(vec_field); + MEM_ROOT_SAVEPOINT memroot_sv; + root_make_savepoint(thd->mem_root, &memroot_sv); + SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); }); + + Neighborhood candidates, start_nodes; + candidates.init(thd->alloc(ef_construction + 7), ef_construction); + start_nodes.init(thd->alloc(ef_construction + 7), ef_construction); + start_nodes.links[start_nodes.num++]= ctx->start; + + const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M); + double log= -std::log(my_rnd(&thd->rand)) * NORMALIZATION_FACTOR; + const uint8_t max_layer= start_nodes.links[0]->max_layer; + uint8_t target_layer= std::min(static_cast(std::floor(log)), max_layer + 1); + int cur_layer; + + FVectorNode *target= new (ctx->alloc_node()) + FVectorNode(ctx, table->file->ref, target_layer, res->ptr()); + if (int err= graph->file->ha_rnd_init(0)) return err; - SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); }); - double new_num= my_rnd(&thd->rand); - double log= -std::log(new_num) * NORMALIZATION_FACTOR; - longlong new_node_layer= std::min(std::floor(log), max_layer + 1); - longlong cur_layer; - - FVectorNode target(&ctx, table->file->ref, new_node_layer, 
res->ptr()); - - for (cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--) + for (cur_layer= max_layer; cur_layer > target_layer; cur_layer--) { - if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates)) - return ctx.err; - start_nodes= candidates; - candidates.empty(); + if (int err= search_layer(ctx, graph, *target, &start_nodes, 1, cur_layer, + &candidates)) + return err; + std::swap(start_nodes, candidates); } for (; cur_layer >= 0; cur_layer--) { - uint max_neighbors= (cur_layer == 0) // heuristics from the paper - ? thd->variables.mhnsw_max_edges_per_node * 2 - : thd->variables.mhnsw_max_edges_per_node; - if (search_layer(&ctx, target, start_nodes, ef_construction, cur_layer, - &candidates)) - return ctx.err; + uint max_neighbors= ctx->max_neighbors(cur_layer); + if (int err= search_layer(ctx, graph, *target, &start_nodes, + ef_construction, cur_layer, &candidates)) + return err; - if (select_neighbors(&ctx, cur_layer, target, candidates, max_neighbors)) - return ctx.err; - start_nodes= candidates; - candidates.empty(); + if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates, + 0, max_neighbors)) + return err; + std::swap(start_nodes, candidates); } - if (target.save()) - return ctx.err; + if (int err= target->save(graph)) + return err; - for (longlong cur_layer= new_node_layer; cur_layer >= 0; cur_layer--) + if (target_layer > max_layer) + ctx->start= target; + + for (cur_layer= target_layer; cur_layer >= 0; cur_layer--) { - uint max_neighbors= (cur_layer == 0) // heuristics from the paper - ? 
thd->variables.mhnsw_max_edges_per_node * 2 - : thd->variables.mhnsw_max_edges_per_node; - // XXX do only one ha_update_row() per node - if (update_second_degree_neighbors(&ctx, cur_layer, max_neighbors, target)) - return ctx.err; + if (int err= update_second_degree_neighbors(ctx, graph, cur_layer, target)) + return err; } dbug_tmp_restore_column_map(&table->read_set, old_map); @@ -593,79 +1016,68 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) { THD *thd= table->in_use; TABLE *graph= table->hlindex; - Field *vec_field= keyinfo->key_part->field; Item_func_vec_distance *fun= (Item_func_vec_distance *)dist; String buf, *res= fun->get_const_arg()->val_str(&buf); - handler *h= table->file; - MHNSW_Context ctx(table, vec_field); + MHNSW_Context *ctx; - if (int err= h->ha_rnd_init(0)) + if (int err= table->file->ha_rnd_init(0)) return err; - if (int err= graph->file->ha_index_init(0, 1)) + int err= MHNSW_Context::acquire(&ctx, table, false); + SCOPE_EXIT([ctx, table](){ ctx->release(table); }); + if (err) return err; - ctx.err= graph->file->ha_index_last(graph->record[0]); - graph->file->ha_index_end(); - if (ctx.err) - return ctx.err; + size_t ef= thd->variables.mhnsw_min_limit; - longlong max_layer= graph->field[FIELD_LAYER]->val_int(); - - List<FVectorNode> candidates; - List<FVectorNode> start_nodes; - - graph->file->position(graph->record[0]); - FVectorNode *start_node= ctx.get_node(graph->file->ref); + Neighborhood candidates, start_nodes; + candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef); + start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef); // one could put all max_layer nodes in start_nodes // but it has no effect on the recall or speed - if (start_nodes.push_back(start_node, &ctx.root)) - return HA_ERR_OUT_OF_MEM; - - ctx.set_lengths(graph->field[FIELD_VEC]->value_length()); - if (int err= start_node->load_from_record()) - return err; + start_nodes.links[start_nodes.num++]= ctx->start; /* if the query vector is NULL or invalid, VEC_DISTANCE will return NULL, so the result
is basically unsorted, we can return rows in any order. For simplicity let's sort by the start_node. */ - if (!res || ctx.byte_len != res->length()) - (res= &buf)->set((char*)start_node->vec, ctx.byte_len, &my_charset_bin); + if (!res || ctx->byte_len != res->length()) + (res= &buf)->set((char*)start_nodes.links[0]->vec, ctx->byte_len, &my_charset_bin); + + const longlong max_layer= start_nodes.links[0]->max_layer; + FVector target(ctx, thd->mem_root, res->ptr()); if (int err= graph->file->ha_rnd_init(0)) return err; - SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); }); - FVector target(&ctx, res->ptr()); - - uint ef_search= thd->variables.mhnsw_min_limit; - for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--) { - if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates)) - return ctx.err; - start_nodes= candidates; - candidates.empty(); + if (int err= search_layer(ctx, graph, target, &start_nodes, 1, cur_layer, + &candidates)) + return err; + std::swap(start_nodes, candidates); } - if (search_layer(&ctx, target, start_nodes, ef_search, 0, &candidates)) - return ctx.err; + if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0, + &candidates)) + return err; - size_t context_size= limit * h->ref_length + sizeof(ulonglong); + if (limit > candidates.num) + limit= candidates.num; + size_t context_size= limit * ctx->tref_len + sizeof(ulonglong); char *context= thd->alloc<char>(context_size); graph->context= context; *(ulonglong*)context= limit; context+= context_size; - while (limit--) + for (size_t i=0; limit--; i++) { - context-= h->ref_length; - memcpy(context, candidates.pop()->get_tref(), h->ref_length); + context-= ctx->tref_len; + memcpy(context, candidates.links[i]->tref(), ctx->tref_len); } DBUG_ASSERT(context - sizeof(ulonglong) == graph->context); @@ -680,7 +1092,17 @@ int mhnsw_read_next(TABLE *table) ref+= sizeof(ulonglong) + (--*limit) * table->file->ref_length; return table->file->ha_rnd_pos(table->record[0], ref); } - return
HA_ERR_END_OF_FILE; + return my_errno= HA_ERR_END_OF_FILE; +} + +void mhnsw_free(TABLE_SHARE *share) +{ + TABLE_SHARE *graph_share= share->hlindex; + if (!graph_share->hlindex_data) + return; + + static_cast<MHNSW_Context*>(graph_share->hlindex_data)->~MHNSW_Context(); + graph_share->hlindex_data= 0; +} const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length) diff --git a/sql/vector_mhnsw.h b/sql/vector_mhnsw.h index 267303dc578..fccf463de54 100644 --- a/sql/vector_mhnsw.h +++ b/sql/vector_mhnsw.h @@ -25,3 +25,6 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length); int mhnsw_insert(TABLE *table, KEY *keyinfo); int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit); int mhnsw_read_next(TABLE *table); +void mhnsw_free(TABLE_SHARE *share); + +extern ulonglong mhnsw_cache_size;