vppinfra: AVX512 in clib_count_equal_*
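
Add AVX512 code paths to clib_count_equal_u32() and clib_count_equal_u8():
a full-width loop that compares 16 u32 (or 64 u8) lanes per iteration
using compare-to-mask instructions, plus a tail that uses a masked load
so the final partial vector is handled without reading past the end of
the buffer.

For reviewers unfamiliar with the idiom, a rough standalone sketch
written directly against the raw AVX512F intrinsics follows
(illustrative only; the function name and the use of __builtin_ctz are
not part of this patch, which goes through the vppinfra wrapper types
instead):

  #include <immintrin.h>
  #include <stdint.h>

  /* count how many leading elements of data equal data[0],
     16 u32 lanes at a time; assumes AVX512F */
  static inline uint32_t
  count_equal_u32_sketch (const uint32_t *data, uint32_t max_count)
  {
    uint32_t count = 0;
    __m512i splat = _mm512_set1_epi32 ((int) data[0]);

    while (count + 16 <= max_count)
      {
        /* one mask bit per lane, set where data[i] == data[0] */
        __mmask16 eq =
          _mm512_cmpeq_epu32_mask (_mm512_loadu_si512 (data), splat);
        if (eq != 0xffff)
          /* first clear bit marks the first mismatching lane */
          return count + __builtin_ctz (~(uint32_t) eq);
        data += 16;
        count += 16;
      }

    if (count < max_count)
      {
        /* masked load zero-fills lanes past max_count, so the tail
           never reads out of bounds */
        __mmask16 m = (__mmask16) ((1u << (max_count - count)) - 1);
        __m512i tail = _mm512_maskz_loadu_epi32 (m, data);
        __mmask16 eq = _mm512_cmpeq_epu32_mask (tail, splat);
        /* AND with m so zero-filled lanes never count as matches */
        count += __builtin_ctz (~((uint32_t) eq & m));
      }
    return count;
  }

The in-tree version additionally guards the masked-load tails behind
CLIB_HAVE_VEC256_MASK_LOAD_STORE / CLIB_HAVE_VEC512_MASK_LOAD_STORE, so
targets without mask loads fall through to the scalar tail loop.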

Type: improvement
Change-Id: I8105d396cfc984e00cf5137bc57122510f5e6437
Signed-off-by: Damjan Marion <damarion@cisco.com>
diff --git a/src/vppinfra/vector/count_equal.h b/src/vppinfra/vector/count_equal.h
index 98770cf..a2aeecd 100644
--- a/src/vppinfra/vector/count_equal.h
+++ b/src/vppinfra/vector/count_equal.h
@@ -67,28 +67,62 @@
   count = 0;
   first = data[0];
 
-#if defined(CLIB_HAVE_VEC256)
+#if defined(CLIB_HAVE_VEC512)
+  u32x16 splat = u32x16_splat (first);
+  while (count + 15 < max_count)
+    {
+      u32 bmp;
+      bmp = u32x16_is_equal_mask (u32x16_load_unaligned (data), splat);
+      if (bmp != pow2_mask (16))
+	return count + count_trailing_zeros (~bmp);
+
+      data += 16;
+      count += 16;
+    }
+  if (count == max_count)
+    return count;
+  else
+    {
+      u32 mask = pow2_mask (max_count - count);
+      u32 bmp =
+	u32x16_is_equal_mask (u32x16_mask_load_zero (data, mask), splat);
+      return count + count_trailing_zeros (~bmp);
+    }
+#elif defined(CLIB_HAVE_VEC256)
   u32x8 splat = u32x8_splat (first);
   while (count + 7 < max_count)
     {
-      u64 bmp;
+      u32 bmp;
+#ifdef __AVX512F__
+      bmp = u32x8_is_equal_mask (u32x8_load_unaligned (data), splat);
+      if (bmp != pow2_mask (8))
+	return count + count_trailing_zeros (~bmp);
+#else
       bmp = u8x32_msb_mask ((u8x32) (u32x8_load_unaligned (data) == splat));
       if (bmp != 0xffffffff)
-	{
-	  count += count_trailing_zeros (~bmp) / 4;
-	  return count;
-	}
+	return count + count_trailing_zeros (~bmp) / 4;
+#endif
 
       data += 8;
       count += 8;
     }
+  if (count == max_count)
+    return count;
+#if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
+  else
+    {
+      u32 mask = pow2_mask (max_count - count);
+      u32 bmp = u32x8_is_equal_mask (u32x8_mask_load_zero (data, mask), splat);
+      return count + count_trailing_zeros (~bmp);
+    }
+#endif
 #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
   u32x4 splat = u32x4_splat (first);
   while (count + 3 < max_count)
     {
       u64 bmp;
       bmp = u8x16_msb_mask ((u8x16) (u32x4_load_unaligned (data) == splat));
-      if (bmp != 0xffff)
+      if (bmp != pow2_mask (4 * 4))
 	{
 	  count += count_trailing_zeros (~bmp) / 4;
 	  return count;
@@ -191,18 +225,50 @@
   count = 0;
   first = data[0];
 
-#if defined(CLIB_HAVE_VEC256)
+#if defined(CLIB_HAVE_VEC512)
+  u8x64 splat = u8x64_splat (first);
+  while (count + 63 < max_count)
+    {
+      u64 bmp;
+      bmp = u8x64_is_equal_mask (u8x64_load_unaligned (data), splat);
+      if (bmp != -1)
+	return count + count_trailing_zeros (~bmp);
+
+      data += 64;
+      count += 64;
+    }
+  if (count == max_count)
+    return count;
+#if defined(CLIB_HAVE_VEC512_MASK_LOAD_STORE)
+  else
+    {
+      u64 mask = pow2_mask (max_count - count);
+      u64 bmp = u8x64_is_equal_mask (u8x64_mask_load_zero (data, mask), splat);
+      return count + count_trailing_zeros (~bmp);
+    }
+#endif
+#elif defined(CLIB_HAVE_VEC256)
   u8x32 splat = u8x32_splat (first);
   while (count + 31 < max_count)
     {
       u64 bmp;
       bmp = u8x32_msb_mask ((u8x32) (u8x32_load_unaligned (data) == splat));
       if (bmp != 0xffffffff)
-	return max_count;
+	return count + count_trailing_zeros (~bmp);
 
       data += 32;
       count += 32;
     }
+  if (count == max_count)
+    return count;
+#if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
+  else
+    {
+      u32 mask = pow2_mask (max_count - count);
+      u64 bmp = u8x32_msb_mask ((u8x32) (u8x32_mask_load_zero (data, mask) == splat));
+      return count + count_trailing_zeros (~bmp);
+    }
+#endif
 #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
   u8x16 splat = u8x16_splat (first);
   while (count + 15 < max_count)
@@ -210,10 +276,7 @@
       u64 bmp;
       bmp = u8x16_msb_mask ((u8x16) (u8x16_load_unaligned (data) == splat));
       if (bmp != 0xffff)
-	{
-	  count += count_trailing_zeros (~bmp);
-	  return count;
-	}
+	return count + count_trailing_zeros (~bmp);
 
       data += 16;
       count += 16;
@@ -235,4 +298,5 @@
     }
   return count;
 }
+
 #endif
diff --git a/src/vppinfra/vector_avx512.h b/src/vppinfra/vector_avx512.h
index a82231a..1a5c252 100644
--- a/src/vppinfra/vector_avx512.h
+++ b/src/vppinfra/vector_avx512.h
@@ -301,6 +301,27 @@
 _ (u64x8, u8, epu64, _mm512, __m512i)
 #undef _
 
+#define _(t, m, e, p, it)                                                     \
+  static_always_inline m t##_is_not_equal_mask (t a, t b)                     \
+  {                                                                           \
+    return p##_cmpneq_##e##_mask ((it) a, (it) b);                            \
+  }
+_ (u8x16, u16, epu8, _mm, __m128i)
+_ (u16x8, u8, epu16, _mm, __m128i)
+_ (u32x4, u8, epu32, _mm, __m128i)
+_ (u64x2, u8, epu64, _mm, __m128i)
+
+_ (u8x32, u32, epu8, _mm256, __m256i)
+_ (u16x16, u16, epu16, _mm256, __m256i)
+_ (u32x8, u8, epu32, _mm256, __m256i)
+_ (u64x4, u8, epu64, _mm256, __m256i)
+
+_ (u8x64, u64, epu8, _mm512, __m512i)
+_ (u16x32, u32, epu16, _mm512, __m512i)
+_ (u32x16, u16, epu32, _mm512, __m512i)
+_ (u64x8, u8, epu64, _mm512, __m512i)
+#undef _
+
 #define _(f, t, fn, it)                                                       \
   static_always_inline t t##_from_##f (f x) { return (t) fn ((it) x); }
 _ (u16x16, u32x16, _mm512_cvtepi16_epi32, __m256i)