vppinfra: small improvement and polishing of AES GCM code
Type: improvement
Change-Id: Ie9661792ec68d4ea3c62ee9eb31b455d3b2b0a42
Signed-off-by: Damjan Marion <damarion@cisco.com>
diff --git a/src/vppinfra/crypto/aes_gcm.h b/src/vppinfra/crypto/aes_gcm.h
index 8a5f76c..3d1b220 100644
--- a/src/vppinfra/crypto/aes_gcm.h
+++ b/src/vppinfra/crypto/aes_gcm.h
@@ -103,9 +103,15 @@
aes_gcm_counter_t Y;
/* ghash */
- ghash_data_t gd;
+ ghash_ctx_t gd;
} aes_gcm_ctx_t;
+static_always_inline u8x16
+aes_gcm_final_block (aes_gcm_ctx_t *ctx)
+{
+ return (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
+}
+
static_always_inline void
aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
{
@@ -137,19 +143,18 @@
}
static_always_inline void
-aes_gcm_ghash_mul_bit_len (aes_gcm_ctx_t *ctx)
+aes_gcm_ghash_mul_final_block (aes_gcm_ctx_t *ctx)
{
- u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
#if N_LANES == 4
u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
- u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), r, 0);
+ u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), aes_gcm_final_block (ctx), 0);
ghash4_mul_next (&ctx->gd, r4, h);
#elif N_LANES == 2
u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
- u8x32 r2 = u8x32_insert_lo (u8x32_zero (), r);
+ u8x32 r2 = u8x32_insert_lo (u8x32_zero (), aes_gcm_final_block (ctx));
ghash2_mul_next (&ctx->gd, r2, h);
#else
- ghash_mul_next (&ctx->gd, r, ctx->Hi[NUM_HI - 1]);
+ ghash_mul_next (&ctx->gd, aes_gcm_final_block (ctx), ctx->Hi[NUM_HI - 1]);
#endif
}
@@ -178,7 +183,7 @@
aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES + 1);
for (i = 1; i < 8; i++)
aes_gcm_ghash_mul_next (ctx, d[i]);
- aes_gcm_ghash_mul_bit_len (ctx);
+ aes_gcm_ghash_mul_final_block (ctx);
aes_gcm_ghash_reduce (ctx);
aes_gcm_ghash_reduce2 (ctx);
aes_gcm_ghash_final (ctx);
@@ -243,16 +248,14 @@
}
if (ctx->operation == AES_GCM_OP_GMAC)
- aes_gcm_ghash_mul_bit_len (ctx);
+ aes_gcm_ghash_mul_final_block (ctx);
aes_gcm_ghash_reduce (ctx);
aes_gcm_ghash_reduce2 (ctx);
aes_gcm_ghash_final (ctx);
}
else if (ctx->operation == AES_GCM_OP_GMAC)
- {
- u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
- ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
- }
+ ctx->T =
+ ghash_mul (aes_gcm_final_block (ctx) ^ ctx->T, ctx->Hi[NUM_HI - 1]);
done:
/* encrypt counter 0 E(Y0, k) */
@@ -267,6 +270,11 @@
const aes_gcm_expaned_key_t Ke0 = ctx->Ke[0];
uword i = 0;
+ /* As counter is stored in network byte order for performance reasons we
+ are incrementing least significant byte only except in case where we
+ overlow. As we are processing four 128, 256 or 512-blocks in parallel
+ except the last round, overflow can happen only when n_blocks == 4 */
+
#if N_LANES == 4
const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };
@@ -275,15 +283,10 @@
4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
};
- /* As counter is stored in network byte order for performance reasons we
- are incrementing least significant byte only except in case where we
- overlow. As we are processing four 512-blocks in parallel except the
- last round, overflow can happen only when n == 4 */
-
if (n_blocks == 4)
for (; i < 2; i++)
{
- r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+ r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
ctx->Y += ctr_inv_4444;
}
@@ -293,7 +296,7 @@
for (; i < n_blocks; i++)
{
- r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+ r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
Yr += ctr_4444;
ctx->Y = (u32x16) aes_gcm_reflect ((u8x64) Yr);
}
@@ -302,7 +305,7 @@
{
for (; i < n_blocks; i++)
{
- r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+ r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
ctx->Y += ctr_inv_4444;
}
}
@@ -311,15 +314,10 @@
const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };
- /* As counter is stored in network byte order for performance reasons we
- are incrementing least significant byte only except in case where we
- overlow. As we are processing four 512-blocks in parallel except the
- last round, overflow can happen only when n == 4 */
-
if (n_blocks == 4)
for (; i < 2; i++)
{
- r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+ r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
ctx->Y += ctr_inv_22;
}
@@ -329,7 +327,7 @@
for (; i < n_blocks; i++)
{
- r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+ r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
Yr += ctr_22;
ctx->Y = (u32x8) aes_gcm_reflect ((u8x32) Yr);
}
@@ -338,7 +336,7 @@
{
for (; i < n_blocks; i++)
{
- r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+ r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
ctx->Y += ctr_inv_22;
}
}
@@ -350,20 +348,20 @@
{
for (; i < n_blocks; i++)
{
- r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+ r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
ctx->Y += ctr_inv_1;
}
ctx->counter += n_blocks;
}
else
{
- r[i++] = Ke0.x1 ^ (u8x16) ctx->Y;
+ r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
ctx->Y += ctr_inv_1;
ctx->counter += 1;
for (; i < n_blocks; i++)
{
- r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+ r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
ctx->counter++;
ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
}
@@ -510,8 +508,7 @@
}
static_always_inline void
-aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst,
- int with_ghash)
+aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst)
{
const aes_gcm_expaned_key_t *k = ctx->Ke;
const aes_mem_t *sv = (aes_mem_t *) src;
@@ -680,7 +677,7 @@
aes_gcm_enc_ctr0_round (ctx, 8);
aes_gcm_enc_ctr0_round (ctx, 9);
- aes_gcm_ghash_mul_bit_len (ctx);
+ aes_gcm_ghash_mul_final_block (ctx);
aes_gcm_ghash_reduce (ctx);
for (i = 10; i < ctx->rounds; i++)
@@ -731,6 +728,7 @@
}
return;
}
+
aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 0);
/* next */
@@ -739,7 +737,7 @@
src += 4 * N;
for (; n_left >= 8 * N; n_left -= 8 * N, src += 8 * N, dst += 8 * N)
- aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+ aes_gcm_calc_double (ctx, d, src, dst);
if (n_left >= 4 * N)
{
@@ -785,8 +783,11 @@
aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
{
aes_data_t d[4] = {};
+ ghash_ctx_t gd;
+
+ /* main encryption loop */
for (; n_left >= 8 * N; n_left -= 8 * N, dst += 8 * N, src += 8 * N)
- aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+ aes_gcm_calc_double (ctx, d, src, dst);
if (n_left >= 4 * N)
{
@@ -798,27 +799,48 @@
src += N * 4;
}
- if (n_left == 0)
- goto done;
+ if (n_left)
+ {
+ ctx->last = 1;
- ctx->last = 1;
+ if (n_left > 3 * N)
+ aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
+ else if (n_left > 2 * N)
+ aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
+ else if (n_left > N)
+ aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
+ else
+ aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+ }
- if (n_left > 3 * N)
- aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
- else if (n_left > 2 * N)
- aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
- else if (n_left > N)
- aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
- else
- aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+ /* interleaved counter 0 encryption E(Y0, k) and ghash of final GCM
+ * (bit length) block */
- u8x16 r;
-done:
- r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
- ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
+ aes_gcm_enc_ctr0_round (ctx, 0);
+ aes_gcm_enc_ctr0_round (ctx, 1);
- /* encrypt counter 0 E(Y0, k) */
- for (int i = 0; i < ctx->rounds + 1; i += 1)
+ ghash_mul_first (&gd, aes_gcm_final_block (ctx) ^ ctx->T,
+ ctx->Hi[NUM_HI - 1]);
+
+ aes_gcm_enc_ctr0_round (ctx, 2);
+ aes_gcm_enc_ctr0_round (ctx, 3);
+
+ ghash_reduce (&gd);
+
+ aes_gcm_enc_ctr0_round (ctx, 4);
+ aes_gcm_enc_ctr0_round (ctx, 5);
+
+ ghash_reduce2 (&gd);
+
+ aes_gcm_enc_ctr0_round (ctx, 6);
+ aes_gcm_enc_ctr0_round (ctx, 7);
+
+ ctx->T = ghash_final (&gd);
+
+ aes_gcm_enc_ctr0_round (ctx, 8);
+ aes_gcm_enc_ctr0_round (ctx, 9);
+
+ for (int i = 10; i < ctx->rounds + 1; i += 1)
aes_gcm_enc_ctr0_round (ctx, i);
}
@@ -835,6 +857,7 @@
.operation = op,
.data_bytes = data_bytes,
.aad_bytes = aad_bytes,
+ .Ke = kd->Ke,
.Hi = kd->Hi },
*ctx = &_ctx;
@@ -843,7 +866,7 @@
Y0[2] = *(u32u *) (ivp + 8);
Y0[3] = 1 << 24;
ctx->EY0 = (u8x16) Y0;
- ctx->Ke = kd->Ke;
+
#if N_LANES == 4
ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
@@ -858,8 +881,6 @@
/* calculate ghash for AAD */
aes_gcm_ghash (ctx, addt, aad_bytes);
- clib_prefetch_load (tag);
-
/* ghash and encrypt/edcrypt */
if (op == AES_GCM_OP_ENCRYPT)
aes_gcm_enc (ctx, src, dst, data_bytes);
diff --git a/src/vppinfra/crypto/ghash.h b/src/vppinfra/crypto/ghash.h
index bae8bad..66e3f6a 100644
--- a/src/vppinfra/crypto/ghash.h
+++ b/src/vppinfra/crypto/ghash.h
@@ -89,7 +89,7 @@
* u8x16 Hi[4];
* ghash_precompute (H, Hi, 4);
*
- * ghash_data_t _gd, *gd = &_gd;
+ * ghash_ctx_t _gd, *gd = &_gd;
* ghash_mul_first (gd, GH ^ b0, Hi[3]);
* ghash_mul_next (gd, b1, Hi[2]);
* ghash_mul_next (gd, b2, Hi[1]);
@@ -154,7 +154,7 @@
u8x32 hi2, lo2, mid2, tmp_lo2, tmp_hi2;
u8x64 hi4, lo4, mid4, tmp_lo4, tmp_hi4;
int pending;
-} ghash_data_t;
+} ghash_ctx_t;
static const u8x16 ghash_poly = {
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
@@ -167,7 +167,7 @@
};
static_always_inline void
-ghash_mul_first (ghash_data_t * gd, u8x16 a, u8x16 b)
+ghash_mul_first (ghash_ctx_t *gd, u8x16 a, u8x16 b)
{
/* a1 * b1 */
gd->hi = gmul_hi_hi (a, b);
@@ -182,7 +182,7 @@
}
static_always_inline void
-ghash_mul_next (ghash_data_t * gd, u8x16 a, u8x16 b)
+ghash_mul_next (ghash_ctx_t *gd, u8x16 a, u8x16 b)
{
/* a1 * b1 */
u8x16 hi = gmul_hi_hi (a, b);
@@ -211,7 +211,7 @@
}
static_always_inline void
-ghash_reduce (ghash_data_t * gd)
+ghash_reduce (ghash_ctx_t *gd)
{
u8x16 r;
@@ -236,14 +236,14 @@
}
static_always_inline void
-ghash_reduce2 (ghash_data_t * gd)
+ghash_reduce2 (ghash_ctx_t *gd)
{
gd->tmp_lo = gmul_lo_lo (ghash_poly2, gd->lo);
gd->tmp_hi = gmul_lo_hi (ghash_poly2, gd->lo);
}
static_always_inline u8x16
-ghash_final (ghash_data_t * gd)
+ghash_final (ghash_ctx_t *gd)
{
return u8x16_xor3 (gd->hi, u8x16_word_shift_right (gd->tmp_lo, 4),
u8x16_word_shift_left (gd->tmp_hi, 4));
@@ -252,7 +252,7 @@
static_always_inline u8x16
ghash_mul (u8x16 a, u8x16 b)
{
- ghash_data_t _gd, *gd = &_gd;
+ ghash_ctx_t _gd, *gd = &_gd;
ghash_mul_first (gd, a, b);
ghash_reduce (gd);
ghash_reduce2 (gd);
@@ -297,7 +297,7 @@
}
static_always_inline void
-ghash4_mul_first (ghash_data_t *gd, u8x64 a, u8x64 b)
+ghash4_mul_first (ghash_ctx_t *gd, u8x64 a, u8x64 b)
{
gd->hi4 = gmul4_hi_hi (a, b);
gd->lo4 = gmul4_lo_lo (a, b);
@@ -306,7 +306,7 @@
}
static_always_inline void
-ghash4_mul_next (ghash_data_t *gd, u8x64 a, u8x64 b)
+ghash4_mul_next (ghash_ctx_t *gd, u8x64 a, u8x64 b)
{
u8x64 hi = gmul4_hi_hi (a, b);
u8x64 lo = gmul4_lo_lo (a, b);
@@ -329,7 +329,7 @@
}
static_always_inline void
-ghash4_reduce (ghash_data_t *gd)
+ghash4_reduce (ghash_ctx_t *gd)
{
u8x64 r;
@@ -356,14 +356,14 @@
}
static_always_inline void
-ghash4_reduce2 (ghash_data_t *gd)
+ghash4_reduce2 (ghash_ctx_t *gd)
{
gd->tmp_lo4 = gmul4_lo_lo (ghash4_poly2, gd->lo4);
gd->tmp_hi4 = gmul4_lo_hi (ghash4_poly2, gd->lo4);
}
static_always_inline u8x16
-ghash4_final (ghash_data_t *gd)
+ghash4_final (ghash_ctx_t *gd)
{
u8x64 r;
u8x32 t;
@@ -410,7 +410,7 @@
}
static_always_inline void
-ghash2_mul_first (ghash_data_t *gd, u8x32 a, u8x32 b)
+ghash2_mul_first (ghash_ctx_t *gd, u8x32 a, u8x32 b)
{
gd->hi2 = gmul2_hi_hi (a, b);
gd->lo2 = gmul2_lo_lo (a, b);
@@ -419,7 +419,7 @@
}
static_always_inline void
-ghash2_mul_next (ghash_data_t *gd, u8x32 a, u8x32 b)
+ghash2_mul_next (ghash_ctx_t *gd, u8x32 a, u8x32 b)
{
u8x32 hi = gmul2_hi_hi (a, b);
u8x32 lo = gmul2_lo_lo (a, b);
@@ -442,7 +442,7 @@
}
static_always_inline void
-ghash2_reduce (ghash_data_t *gd)
+ghash2_reduce (ghash_ctx_t *gd)
{
u8x32 r;
@@ -469,14 +469,14 @@
}
static_always_inline void
-ghash2_reduce2 (ghash_data_t *gd)
+ghash2_reduce2 (ghash_ctx_t *gd)
{
gd->tmp_lo2 = gmul2_lo_lo (ghash2_poly2, gd->lo2);
gd->tmp_hi2 = gmul2_lo_hi (ghash2_poly2, gd->lo2);
}
static_always_inline u8x16
-ghash2_final (ghash_data_t *gd)
+ghash2_final (ghash_ctx_t *gd)
{
u8x32 r;