From 11a6c1b30a12c448ddfe05e1b818a6a228e90e43 Mon Sep 17 00:00:00 2001
From: Sergey Vojtovich
Date: Thu, 28 Nov 2024 23:02:29 +0400
Subject: [PATCH] MDEV-34699 - mhnsw: support aarch64 SIMD instructions

SIMD implementations of bloom filters and dot product calculation.

A microbenchmark shows 1.7x dot product performance improvement compared
to regular -O2/-O3 builds and 2.4x compared to builds with
auto-vectorization disabled.

Performance improvement (microbenchmark) for bloom filters is less
exciting, within the 10-30% ballpark depending on compiler options and
load.

Misc implementation notes:
CalcHash: no _mm256_shuffle_epi8(), use explicit XOR/shift.
CalcHash: no 64-bit multiplication, do scalar multiplication.
ConstructMask/Query: no _mm256_i64gather_epi64, access array elements
explicitly.
Query: no _mm256_movemask_epi8, accumulate bits manually.

Closes #3671
---
 sql/bloom_filters.h | 167 +++++++++++++++++++++++++++++++++++++++++++-
 sql/vector_mhnsw.cc |  44 ++++++++++++
 2 files changed, 208 insertions(+), 3 deletions(-)

diff --git a/sql/bloom_filters.h b/sql/bloom_filters.h
index 315673d04b3..0dafaf1b6f3 100644
--- a/sql/bloom_filters.h
+++ b/sql/bloom_filters.h
@@ -28,9 +28,18 @@ SOFTWARE.
 #include
 #include
 #include
+
+/*
+  Use gcc function multiversioning to optimize for a specific CPU with run-time
+  detection. Works only for x86, for other architectures we provide only one
+  implementation for now.
+*/
+#define DEFAULT_IMPLEMENTATION
+#if __GNUC__ > 7
+#ifdef __x86_64__
 #ifdef HAVE_IMMINTRIN_H
 #include <immintrin.h>
-#if __GNUC__ > 7 && defined __x86_64__
+#undef DEFAULT_IMPLEMENTATION
 #define DEFAULT_IMPLEMENTATION __attribute__ ((target ("default")))
 #define AVX2_IMPLEMENTATION __attribute__ ((target ("avx2,avx,fma")))
 #if __GNUC__ > 9
@@ -38,8 +47,11 @@ SOFTWARE.
 #endif
 #endif
 #endif
-#ifndef DEFAULT_IMPLEMENTATION
-#define DEFAULT_IMPLEMENTATION
+#ifdef __aarch64__
+#include <arm_neon.h>
+#undef DEFAULT_IMPLEMENTATION
+#define NEON_IMPLEMENTATION
+#endif
 #endif
 
 template <typename T>
@@ -177,9 +189,157 @@ struct PatternedSimdBloomFilter
     basically, unnoticeable, well below the noise level
   */
 #endif
+#ifdef NEON_IMPLEMENTATION
+  uint64x2_t CalcHash(uint64x2_t vecData)
+  {
+    static constexpr uint64_t prime_mx2= 0x9FB21C651E98DF25ULL;
+    static constexpr uint64_t bitflip= 0xC73AB174C5ECD5A2ULL;
+    uint64x2_t step1= veorq_u64(vecData, vdupq_n_u64(bitflip));
+    uint64x2_t step2= veorq_u64(vshrq_n_u64(step1, 48), vshlq_n_u64(step1, 16));
+    uint64x2_t step3= veorq_u64(vshrq_n_u64(step1, 24), vshlq_n_u64(step1, 40));
+    uint64x2_t step4= veorq_u64(step1, veorq_u64(step2, step3));
+    uint64x2_t step5;
+    step5= vsetq_lane_u64(vgetq_lane_u64(step4, 0) * prime_mx2, step4, 0);
+    step5= vsetq_lane_u64(vgetq_lane_u64(step4, 1) * prime_mx2, step5, 1);
+    uint64x2_t step6= vshrq_n_u64(step5, 35);
+    uint64x2_t step7= vaddq_u64(step6, vdupq_n_u64(8));
+    uint64x2_t step8= veorq_u64(step5, step7);
+    uint64x2_t step9;
+    step9= vsetq_lane_u64(vgetq_lane_u64(step8, 0) * prime_mx2, step8, 0);
+    step9= vsetq_lane_u64(vgetq_lane_u64(step8, 1) * prime_mx2, step9, 1);
+    return veorq_u64(step9, vshrq_n_u64(step9, 28));
+  }
+
+  uint64x2_t GetBlockIdx(uint64x2_t vecHash)
+  {
+    uint64x2_t vecNumBlocksMask= vdupq_n_u64(num_blocks - 1);
+    uint64x2_t vecBlockIdx= vshrq_n_u64(vecHash, mask_idx_bits + rotate_bits);
+    return vandq_u64(vecBlockIdx, vecNumBlocksMask);
+  }
+
+  uint64x2_t ConstructMask(uint64x2_t vecHash)
+  {
+    uint64x2_t vecMaskIdxMask= vdupq_n_u64((1 << mask_idx_bits) - 1);
+    uint64x2_t vecMaskMask= vdupq_n_u64((1ull << bits_per_mask) - 1);
+
+    uint64x2_t vecMaskIdx= vandq_u64(vecHash, vecMaskIdxMask);
+    uint64x2_t vecMaskByteIdx= vshrq_n_u64(vecMaskIdx, 3);
+    /*
+      Shift right in NEON is implemented as shift left by a negative value.
+      Do the negation here.
+    */
+    int64x2_t vecMaskBitIdx=
+      vsubq_s64(vdupq_n_s64(0),
+                vreinterpretq_s64_u64(vandq_u64(vecMaskIdx, vdupq_n_u64(0x7))));
+    uint64x2_t vecRawMasks= vdupq_n_u64(*reinterpret_cast<const uint64_t *>
+                                        (masks + vgetq_lane_u64(vecMaskByteIdx, 0)));
+    vecRawMasks= vsetq_lane_u64(*reinterpret_cast<const uint64_t *>
+                                (masks + vgetq_lane_u64(vecMaskByteIdx, 1)), vecRawMasks, 1);
+    uint64x2_t vecUnrotated=
+      vandq_u64(vshlq_u64(vecRawMasks, vecMaskBitIdx), vecMaskMask);
+
+    int64x2_t vecRotation=
+      vreinterpretq_s64_u64(vandq_u64(vshrq_n_u64(vecHash, mask_idx_bits),
+                                      vdupq_n_u64((1 << rotate_bits) - 1)));
+    uint64x2_t vecShiftUp= vshlq_u64(vecUnrotated, vecRotation);
+    uint64x2_t vecShiftDown=
+      vshlq_u64(vecUnrotated, vsubq_s64(vecRotation, vdupq_n_s64(64)));
+    return vorrq_u64(vecShiftDown, vecShiftUp);
+  }
+
+  void Insert(const T **data)
+  {
+    uint64x2_t vecDataA= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 0));
+    uint64x2_t vecDataB= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 2));
+    uint64x2_t vecDataC= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 4));
+    uint64x2_t vecDataD= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 6));
+
+    uint64x2_t vecHashA= CalcHash(vecDataA);
+    uint64x2_t vecHashB= CalcHash(vecDataB);
+    uint64x2_t vecHashC= CalcHash(vecDataC);
+    uint64x2_t vecHashD= CalcHash(vecDataD);
+
+    uint64x2_t vecMaskA= ConstructMask(vecHashA);
+    uint64x2_t vecMaskB= ConstructMask(vecHashB);
+    uint64x2_t vecMaskC= ConstructMask(vecHashC);
+    uint64x2_t vecMaskD= ConstructMask(vecHashD);
+
+    uint64x2_t vecBlockIdxA= GetBlockIdx(vecHashA);
+    uint64x2_t vecBlockIdxB= GetBlockIdx(vecHashB);
+    uint64x2_t vecBlockIdxC= GetBlockIdx(vecHashC);
+    uint64x2_t vecBlockIdxD= GetBlockIdx(vecHashD);
+
+    uint64_t block0= vgetq_lane_u64(vecBlockIdxA, 0);
+    uint64_t block1= vgetq_lane_u64(vecBlockIdxA, 1);
+    uint64_t block2= vgetq_lane_u64(vecBlockIdxB, 0);
+    uint64_t block3= vgetq_lane_u64(vecBlockIdxB, 1);
+    uint64_t block4= vgetq_lane_u64(vecBlockIdxC, 0);
+    uint64_t block5= vgetq_lane_u64(vecBlockIdxC, 1);
+    uint64_t block6= vgetq_lane_u64(vecBlockIdxD, 0);
+    uint64_t block7= vgetq_lane_u64(vecBlockIdxD, 1);
+
+    bv[block0]|= vgetq_lane_u64(vecMaskA, 0);
+    bv[block1]|= vgetq_lane_u64(vecMaskA, 1);
+    bv[block2]|= vgetq_lane_u64(vecMaskB, 0);
+    bv[block3]|= vgetq_lane_u64(vecMaskB, 1);
+    bv[block4]|= vgetq_lane_u64(vecMaskC, 0);
+    bv[block5]|= vgetq_lane_u64(vecMaskC, 1);
+    bv[block6]|= vgetq_lane_u64(vecMaskD, 0);
+    bv[block7]|= vgetq_lane_u64(vecMaskD, 1);
+  }
+
+  uint8_t Query(T **data)
+  {
+    uint64x2_t vecDataA= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 0));
+    uint64x2_t vecDataB= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 2));
+    uint64x2_t vecDataC= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 4));
+    uint64x2_t vecDataD= vld1q_u64(reinterpret_cast<const uint64_t *>(data + 6));
+
+    uint64x2_t vecHashA= CalcHash(vecDataA);
+    uint64x2_t vecHashB= CalcHash(vecDataB);
+    uint64x2_t vecHashC= CalcHash(vecDataC);
+    uint64x2_t vecHashD= CalcHash(vecDataD);
+
+    uint64x2_t vecMaskA= ConstructMask(vecHashA);
+    uint64x2_t vecMaskB= ConstructMask(vecHashB);
+    uint64x2_t vecMaskC= ConstructMask(vecHashC);
+    uint64x2_t vecMaskD= ConstructMask(vecHashD);
+
+    uint64x2_t vecBlockIdxA= GetBlockIdx(vecHashA);
+    uint64x2_t vecBlockIdxB= GetBlockIdx(vecHashB);
+    uint64x2_t vecBlockIdxC= GetBlockIdx(vecHashC);
+    uint64x2_t vecBlockIdxD= GetBlockIdx(vecHashD);
+
+    uint64x2_t vecBloomA= vdupq_n_u64(bv[vgetq_lane_u64(vecBlockIdxA, 0)]);
+    vecBloomA= vsetq_lane_u64(bv[vgetq_lane_u64(vecBlockIdxA, 1)], vecBloomA, 1);
+    uint64x2_t vecBloomB= vdupq_n_u64(bv[vgetq_lane_u64(vecBlockIdxB, 0)]);
+    vecBloomB= vsetq_lane_u64(bv[vgetq_lane_u64(vecBlockIdxB, 1)], vecBloomB, 1);
+    uint64x2_t vecBloomC= vdupq_n_u64(bv[vgetq_lane_u64(vecBlockIdxC, 0)]);
+    vecBloomC= vsetq_lane_u64(bv[vgetq_lane_u64(vecBlockIdxC, 1)], vecBloomC, 1);
+    uint64x2_t vecBloomD= vdupq_n_u64(bv[vgetq_lane_u64(vecBlockIdxD, 0)]);
+    vecBloomD= vsetq_lane_u64(bv[vgetq_lane_u64(vecBlockIdxD, 1)], vecBloomD, 1);
+
+    uint64x2_t vecCmpA= vceqq_u64(vandq_u64(vecMaskA, vecBloomA), vecMaskA);
+    uint64x2_t vecCmpB= vceqq_u64(vandq_u64(vecMaskB, vecBloomB), vecMaskB);
+    uint64x2_t vecCmpC= vceqq_u64(vandq_u64(vecMaskC, vecBloomC), vecMaskC);
+    uint64x2_t vecCmpD= vceqq_u64(vandq_u64(vecMaskD, vecBloomD), vecMaskD);
+
+    return
+      (vgetq_lane_u64(vecCmpA, 0) & 0x01) |
+      (vgetq_lane_u64(vecCmpA, 1) & 0x02) |
+      (vgetq_lane_u64(vecCmpB, 0) & 0x04) |
+      (vgetq_lane_u64(vecCmpB, 1) & 0x08) |
+      (vgetq_lane_u64(vecCmpC, 0) & 0x10) |
+      (vgetq_lane_u64(vecCmpC, 1) & 0x20) |
+      (vgetq_lane_u64(vecCmpD, 0) & 0x40) |
+      (vgetq_lane_u64(vecCmpD, 1) & 0x80);
+  }
+#endif
+
 /********************************************************
  ********* non-SIMD fallback version ********************/
+#ifdef DEFAULT_IMPLEMENTATION
   uint64_t CalcHash_1(const T* data)
   {
     static constexpr uint64_t prime_mx2= 0x9FB21C651E98DF25ULL;
@@ -240,6 +400,7 @@ struct PatternedSimdBloomFilter
     }
     return res_bits;
   }
+#endif
 
   int n;
   float epsilon;
diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc
index 449a7daf344..8fbeffb3a58 100644
--- a/sql/vector_mhnsw.cc
+++ b/sql/vector_mhnsw.cc
@@ -188,7 +188,50 @@ struct FVector
   }
 #endif
 
+
+/*
+  ARM NEON implementation. A microbenchmark shows 1.7x dot_product() performance
+  improvement compared to regular -O2/-O3 builds and 2.4x compared to builds
+  with auto-vectorization disabled.
+
+  There seems to be no performance difference between the vmull+vmull_high and
+  vmull+vmlal2_high implementations.
+*/
+#ifdef NEON_IMPLEMENTATION
+  static constexpr size_t NEON_bytes= 128 / 8;
+  static constexpr size_t NEON_dims= NEON_bytes / sizeof(int16_t);
+
+  static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
+  {
+    int64_t d= 0;
+    for (size_t i= 0; i < (len + NEON_dims - 1) / NEON_dims; i++)
+    {
+      int16x8_t p1= vld1q_s16(v1);
+      int16x8_t p2= vld1q_s16(v2);
+      d+= vaddlvq_s32(vmull_s16(vget_low_s16(p1), vget_low_s16(p2))) +
+          vaddlvq_s32(vmull_high_s16(p1, p2));
+      v1+= NEON_dims;
+      v2+= NEON_dims;
+    }
+    return static_cast<float>(d);
+  }
+
+  static size_t alloc_size(size_t n)
+  { return alloc_header + MY_ALIGN(n * 2, NEON_bytes) + NEON_bytes - 1; }
+
+  static FVector *align_ptr(void *ptr)
+  { return (FVector*) (MY_ALIGN(((intptr) ptr) + alloc_header, NEON_bytes)
+                       - alloc_header); }
+
+  void fix_tail(size_t vec_len)
+  {
+    bzero(dims + vec_len, (MY_ALIGN(vec_len, NEON_dims) - vec_len) * 2);
+  }
+#endif
+
 /************* no-SIMD default ******************************************/
+#ifdef DEFAULT_IMPLEMENTATION
   DEFAULT_IMPLEMENTATION
   static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
   {
@@ -206,6 +249,7 @@ struct FVector
 
   DEFAULT_IMPLEMENTATION
   void fix_tail(size_t) { }
+#endif
 
   float distance_to(const FVector *other, size_t vec_len) const
   {
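
The widen-multiply-accumulate pattern used by the NEON dot_product() above can
be exercised in isolation. The following is a minimal standalone sketch,
separate from the patch: it assumes an aarch64 compiler with arm_neon.h, and
that the input length is a multiple of 8 int16 lanes, which is what fix_tail()
guarantees for FVector data. Function and variable names here are illustrative
only.

  #include <arm_neon.h>
  #include <stddef.h>
  #include <stdint.h>
  #include <stdio.h>

  /* Same kernel shape as FVector::dot_product(): widen int16 lanes to int32
     products, then reduce each half to a 64-bit horizontal sum. */
  static int64_t neon_dot(const int16_t *v1, const int16_t *v2, size_t len)
  {
    int64_t d= 0;
    for (size_t i= 0; i < len / 8; i++)
    {
      int16x8_t p1= vld1q_s16(v1 + i * 8);
      int16x8_t p2= vld1q_s16(v2 + i * 8);
      d+= vaddlvq_s32(vmull_s16(vget_low_s16(p1), vget_low_s16(p2))) +
          vaddlvq_s32(vmull_high_s16(p1, p2));
    }
    return d;
  }

  int main(void)
  {
    int16_t a[16], b[16];          /* two NEON blocks of arbitrary test data */
    int64_t scalar= 0;
    for (int i= 0; i < 16; i++)
    {
      a[i]= (int16_t)(i + 1);
      b[i]= (int16_t)(3 * i - 7);
      scalar+= (int64_t) a[i] * b[i];
    }
    printf("neon=%lld scalar=%lld\n",
           (long long) neon_dot(a, b, 16), (long long) scalar);
    return 0;
  }

An equivalent formulation would fold the second multiply into the accumulator
with vmlal_high_s16 (the SMLAL2 form); as the comment in the patch notes, the
two variants measured the same in the microbenchmark.

The "shift right is a shift left by a negative value" remark in ConstructMask()
can also be seen on its own. Again a standalone sketch assuming arm_neon.h, not
code from the patch: vshlq_u64() takes a per-lane signed shift count, so
negating the count turns the left shift into a logical right shift.

  #include <arm_neon.h>
  #include <inttypes.h>
  #include <stdio.h>

  int main(void)
  {
    uint64x2_t v= vdupq_n_u64(0xFF00);
    int64x2_t shift= vdupq_n_s64(-8);        /* i.e. logical shift right by 8 */
    uint64x2_t r= vshlq_u64(v, shift);
    printf("0x%" PRIx64 "\n", vgetq_lane_u64(r, 0));    /* prints 0xff */
    return 0;
  }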