vppinfra: AVX512 in clib_count_equal_*
Type: improvement
Change-Id: I8105d396cfc984e00cf5137bc57122510f5e6437
Signed-off-by: Damjan Marion <damarion@cisco.com>
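
Note (not part of the change): the new code paths follow one pattern: splat the
first element, compare a full vector against it, and use the resulting compare
mask to find the first mismatching lane; a tail shorter than a full vector is
read with a zero-filling masked load so nothing past max_count is touched.
Below is a minimal standalone sketch of that pattern using raw AVX-512BW
intrinsics instead of the clib_* wrappers. The function name
count_equal_u8_sketch is purely illustrative; it assumes max_count >= 1 and an
AVX-512BW capable CPU.

#include <immintrin.h>
#include <stdint.h>

static inline uint64_t
count_equal_u8_sketch (const uint8_t *data, uint64_t max_count)
{
  uint64_t count = 0;
  __m512i splat = _mm512_set1_epi8 ((char) data[0]);

  /* full 64-byte blocks: 64 lanes compared at once */
  while (count + 63 < max_count)
    {
      uint64_t eq =
        _mm512_cmpeq_epu8_mask (_mm512_loadu_si512 ((const void *) data), splat);
      if (eq != ~0ULL)
        return count + __builtin_ctzll (~eq); /* first differing lane */
      data += 64;
      count += 64;
    }

  if (count == max_count)
    return count;

  /* tail (1..63 elements left): the masked load zero-fills lanes past the
     end, so only the low (max_count - count) bits of the result matter */
  uint64_t mask = (1ULL << (max_count - count)) - 1;
  uint64_t eq =
    _mm512_cmpeq_epu8_mask (_mm512_maskz_loadu_epi8 (mask, data), splat);
  eq &= mask;
  return count + __builtin_ctzll (~eq);
}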
diff --git a/src/vppinfra/vector/count_equal.h b/src/vppinfra/vector/count_equal.h
index 98770cf..a2aeecd 100644
--- a/src/vppinfra/vector/count_equal.h
+++ b/src/vppinfra/vector/count_equal.h
@@ -67,28 +67,62 @@
count = 0;
first = data[0];
-#if defined(CLIB_HAVE_VEC256)
+#if defined(CLIB_HAVE_VEC512)
+ u32x16 splat = u32x16_splat (first);
+ while (count + 15 < max_count)
+ {
+ u32 bmp;
+ bmp = u32x16_is_equal_mask (u32x16_load_unaligned (data), splat);
+ if (bmp != pow2_mask (16))
+ return count + count_trailing_zeros (~bmp);
+
+ data += 16;
+ count += 16;
+ }
+ if (count == max_count)
+ return count;
+ else
+ {
+ u32 mask = pow2_mask (max_count - count);
+ u32 bmp =
+ u32x16_is_equal_mask (u32x16_mask_load_zero (data, mask), splat);
+ return count + count_trailing_zeros (~bmp);
+ }
+#elif defined(CLIB_HAVE_VEC256)
u32x8 splat = u32x8_splat (first);
while (count + 7 < max_count)
{
- u64 bmp;
+ u32 bmp;
+#ifdef __AVX512F__
+ bmp = u32x8_is_equal_mask (u32x8_load_unaligned (data), splat);
+ if (bmp != pow2_mask (8))
+ return count + count_trailing_zeros (~bmp);
+#else
bmp = u8x32_msb_mask ((u8x32) (u32x8_load_unaligned (data) == splat));
if (bmp != 0xffffffff)
- {
- count += count_trailing_zeros (~bmp) / 4;
- return count;
- }
+ return count + count_trailing_zeros (~bmp) / 4;
+#endif
data += 8;
count += 8;
}
+ if (count == max_count)
+ return count;
+#if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
+ else
+ {
+ u32 mask = pow2_mask (max_count - count);
+ u32 bmp = u32x8_is_equal_mask (u32x8_mask_load_zero (data, mask), splat);
+ return count + count_trailing_zeros (~bmp);
+ }
+#endif
#elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
u32x4 splat = u32x4_splat (first);
while (count + 3 < max_count)
{
u64 bmp;
bmp = u8x16_msb_mask ((u8x16) (u32x4_load_unaligned (data) == splat));
- if (bmp != 0xffff)
+ if (bmp != pow2_mask (4 * 4))
{
count += count_trailing_zeros (~bmp) / 4;
return count;
@@ -191,18 +225,50 @@
count = 0;
first = data[0];
-#if defined(CLIB_HAVE_VEC256)
+#if defined(CLIB_HAVE_VEC512)
+ u8x64 splat = u8x64_splat (first);
+ while (count + 63 < max_count)
+ {
+ u64 bmp;
+ bmp = u8x64_is_equal_mask (u8x64_load_unaligned (data), splat);
+ if (bmp != -1)
+ return count + count_trailing_zeros (~bmp);
+
+ data += 64;
+ count += 64;
+ }
+ if (count == max_count)
+ return count;
+#if defined(CLIB_HAVE_VEC512_MASK_LOAD_STORE)
+ else
+ {
+ u64 mask = pow2_mask (max_count - count);
+ u64 bmp = u8x64_is_equal_mask (u8x64_mask_load_zero (data, mask), splat);
+ return count + count_trailing_zeros (~bmp);
+ }
+#endif
+#elif defined(CLIB_HAVE_VEC256)
u8x32 splat = u8x32_splat (first);
while (count + 31 < max_count)
{
u64 bmp;
bmp = u8x32_msb_mask ((u8x32) (u8x32_load_unaligned (data) == splat));
if (bmp != 0xffffffff)
- return max_count;
+ return count + count_trailing_zeros (~bmp);
data += 32;
count += 32;
}
+ if (count == max_count)
+ return count;
+#if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
+ else
+ {
+ u32 mask = pow2_mask (max_count - count);
+      u64 bmp =
+        u8x32_msb_mask ((u8x32) (u8x32_mask_load_zero (data, mask) == splat));
+ return count + count_trailing_zeros (~bmp);
+ }
+#endif
#elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
u8x16 splat = u8x16_splat (first);
while (count + 15 < max_count)
@@ -210,10 +276,7 @@
u64 bmp;
bmp = u8x16_msb_mask ((u8x16) (u8x16_load_unaligned (data) == splat));
if (bmp != 0xffff)
- {
- count += count_trailing_zeros (~bmp);
- return count;
- }
+ return count + count_trailing_zeros (~bmp);
data += 16;
count += 16;
@@ -235,4 +298,5 @@
}
return count;
}
+
#endif
diff --git a/src/vppinfra/vector_avx512.h b/src/vppinfra/vector_avx512.h
index a82231a..1a5c252 100644
--- a/src/vppinfra/vector_avx512.h
+++ b/src/vppinfra/vector_avx512.h
@@ -301,6 +301,27 @@
_ (u64x8, u8, epu64, _mm512, __m512i)
#undef _
+#define _(t, m, e, p, it) \
+ static_always_inline m t##_is_not_equal_mask (t a, t b) \
+ { \
+ return p##_cmpneq_##e##_mask ((it) a, (it) b); \
+ }
+_ (u8x16, u16, epu8, _mm, __m128i)
+_ (u16x8, u8, epu16, _mm, __m128i)
+_ (u32x4, u8, epu32, _mm, __m128i)
+_ (u64x2, u8, epu64, _mm, __m128i)
+
+_ (u8x32, u32, epu8, _mm256, __m256i)
+_ (u16x16, u16, epu16, _mm256, __m256i)
+_ (u32x8, u8, epu32, _mm256, __m256i)
+_ (u64x4, u8, epu64, _mm256, __m256i)
+
+_ (u8x64, u64, epu8, _mm512, __m512i)
+_ (u16x32, u32, epu16, _mm512, __m512i)
+_ (u32x16, u16, epu32, _mm512, __m512i)
+_ (u64x8, u8, epu64, _mm512, __m512i)
+#undef _
+
#define _(f, t, fn, it) \
static_always_inline t t##_from_##f (f x) { return (t) fn ((it) x); }
_ (u16x16, u32x16, _mm512_cvtepi16_epi32, __m256i)