/* SPDX-License-Identifier: Apache-2.0
 * Copyright(c) 2024 Cisco Systems, Inc.
 */

#ifndef __crypto_aes_ctr_h__
#define __crypto_aes_ctr_h__

#include <vppinfra/clib.h>
#include <vppinfra/vector.h>
#include <vppinfra/cache.h>
#include <vppinfra/string.h>
#include <vppinfra/crypto/aes.h>

typedef struct
{
  const aes_expaned_key_t exp_key[AES_KEY_ROUNDS (AES_KEY_256) + 1];
} aes_ctr_key_data_t;

typedef struct
{
  const aes_expaned_key_t exp_key[AES_KEY_ROUNDS (AES_KEY_256) + 1];
  aes_counter_t ctr;		   /* counter (reflected) */
  u8 keystream_bytes[N_AES_BYTES]; /* keystream leftovers */
  u32 n_keystream_bytes;	   /* number of keystream leftovers */
} aes_ctr_ctx_t;
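
/*
 * The counter is stored byte-reflected (little-endian) so that the
 * big-endian per-block increment required by CTR mode reduces to a plain
 * 32-bit vector add; aes_reflect () restores byte order right before the
 * first AES round. On multi-lane builds each 128-bit lane carries the
 * counter for one block, pre-offset by its lane index in
 * clib_aes_ctr_init ().
 */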

static_always_inline aes_counter_t
aes_ctr_one_block (aes_ctr_ctx_t *ctx, aes_counter_t ctr, const u8 *src,
		   u8 *dst, u32 n_parallel, u32 n_bytes, int rounds, int last)
{
  u32 __clib_aligned (N_AES_BYTES)
  inc[] = { N_AES_LANES, 0, 0, 0, N_AES_LANES, 0, 0, 0,
	    N_AES_LANES, 0, 0, 0, N_AES_LANES, 0, 0, 0 };
  const aes_expaned_key_t *k = ctx->exp_key;
  const aes_mem_t *sv = (aes_mem_t *) src;
  aes_mem_t *dv = (aes_mem_t *) dst;
  aes_data_t d[4], t[4];
  u32 r;

  /* leave only the trailing (possibly partial) block count in n_bytes */
  n_bytes -= (n_parallel - 1) * N_AES_BYTES;

  /* AES First Round */
  for (int i = 0; i < n_parallel; i++)
    {
#if N_AES_LANES == 4
      t[i] = k[0].x4 ^ (u8x64) aes_reflect ((u8x64) ctr);
#elif N_AES_LANES == 2
      t[i] = k[0].x2 ^ (u8x32) aes_reflect ((u8x32) ctr);
#else
      t[i] = k[0].x1 ^ (u8x16) aes_reflect ((u8x16) ctr);
#endif
      /* 32-bit increment of the counter word in each lane */
      ctr += *(aes_counter_t *) inc;
    }

  /* Load Data */
  for (int i = 0; i < n_parallel - last; i++)
    d[i] = sv[i];

  if (last)
    d[n_parallel - 1] =
      aes_load_partial ((u8 *) (sv + n_parallel - 1), n_bytes);

  /* AES Intermediate Rounds */
  for (r = 1; r < rounds; r++)
    aes_enc_round (t, k + r, n_parallel);

  /* AES Last Round */
  aes_enc_last_round (t, d, k + r, n_parallel);

  /* Store Data */
  for (int i = 0; i < n_parallel - last; i++)
    dv[i] = d[i];

  if (last)
    {
      aes_store_partial (d[n_parallel - 1], dv + n_parallel - 1, n_bytes);
      /* stash the keystream block so a follow-up call can use its tail */
      *(aes_data_t *) ctx->keystream_bytes = t[n_parallel - 1];
      ctx->n_keystream_bytes = N_AES_BYTES - n_bytes;
    }

  return ctr;
}
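
/*
 * Worked example (illustrative, single-lane build where N_AES_BYTES is
 * 16): a 40-byte trailing chunk arrives here as n_parallel = 3,
 * n_bytes = 40, last = 1. n_bytes is first reduced by (3 - 1) * 16 to 8,
 * the two leading blocks take the full-width load/store path, and the
 * remaining 8 bytes go through aes_load_partial () /
 * aes_store_partial (); the other 8 keystream bytes stay in ctx.
 */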

static_always_inline void
clib_aes_ctr_init (aes_ctr_ctx_t *ctx, const aes_ctr_key_data_t *kd,
		   const u8 *iv, aes_key_size_t ks)
{
  u32x4 ctr = (u32x4) u8x16_reflect (*(u8x16u *) iv);
#if N_AES_LANES == 4
  ctx->ctr = (aes_counter_t) u32x16_splat_u32x4 (ctr) +
	     (u32x16){ 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0 };
#elif N_AES_LANES == 2
  ctx->ctr = (aes_counter_t) u32x8_splat_u32x4 (ctr) +
	     (u32x8){ 0, 0, 0, 0, 1, 0, 0, 0 };
#else
  ctx->ctr = ctr;
#endif
  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
    ((aes_expaned_key_t *) ctx->exp_key)[i] = kd->exp_key[i];
  ctx->n_keystream_bytes = 0;
}
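
/*
 * Note: the iv argument is read as one full 16-byte initial counter
 * block (e.g. nonce, IV and initial block counter already packed by the
 * caller, as in RFC 3686 style usage), not as a shorter nonce that gets
 * padded here.
 */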

static_always_inline void
clib_aes_ctr_transform (aes_ctr_ctx_t *ctx, const u8 *src, u8 *dst,
			u32 n_bytes, aes_key_size_t ks)
{
  int r = AES_KEY_ROUNDS (ks);
  aes_counter_t ctr = ctx->ctr;

  /* consume keystream left over from a previous partial block first */
  if (ctx->n_keystream_bytes)
    {
      u8 *kb = ctx->keystream_bytes + N_AES_BYTES - ctx->n_keystream_bytes;

      if (ctx->n_keystream_bytes >= n_bytes)
	{
	  for (int i = 0; i < n_bytes; i++)
	    dst[i] = src[i] ^ kb[i];
	  ctx->n_keystream_bytes -= n_bytes;
	  return;
	}

      for (int i = 0; i < ctx->n_keystream_bytes; i++)
	dst++[0] = src++[0] ^ kb[i];

      n_bytes -= ctx->n_keystream_bytes;
      ctx->n_keystream_bytes = 0;
    }

  /* main loop - 4 blocks at a time */
  for (int n = 4 * N_AES_BYTES; n_bytes >= n; n_bytes -= n, dst += n, src += n)
    ctr = aes_ctr_one_block (ctx, ctr, src, dst, 4, n, r, 0);

  /* trailing 1-4 (possibly partial) blocks */
  if (n_bytes)
    {
      if (n_bytes > 3 * N_AES_BYTES)
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 4, n_bytes, r, 1);
      else if (n_bytes > 2 * N_AES_BYTES)
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 3, n_bytes, r, 1);
      else if (n_bytes > N_AES_BYTES)
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 2, n_bytes, r, 1);
      else
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 1, n_bytes, r, 1);
    }
  else
    ctx->n_keystream_bytes = 0;

  ctx->ctr = ctr;
}
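
/*
 * Streaming usage sketch (hypothetical buffers; any chunking works):
 *
 *   aes_ctr_ctx_t ctx;
 *   clib_aes_ctr_init (&ctx, &kd, iv, AES_KEY_128);
 *   clib_aes_ctr_transform (&ctx, msg, out, 100, AES_KEY_128);
 *   clib_aes_ctr_transform (&ctx, msg + 100, out + 100, len - 100,
 *			     AES_KEY_128);
 *
 * A chunk boundary may fall mid-block; the unused keystream tail is
 * carried in ctx and consumed at the start of the next call.
 */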

static_always_inline void
clib_aes_ctr_key_expand (aes_ctr_key_data_t *kd, const u8 *key,
			 aes_key_size_t ks)
{
  u8x16 ek[AES_KEY_ROUNDS (AES_KEY_256) + 1];
  aes_expaned_key_t *k = (aes_expaned_key_t *) kd->exp_key;

  /* expand AES key */
  aes_key_expand (ek, key, ks);

  /* broadcast each round key to all vector lanes */
  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
    k[i].lanes[0] = k[i].lanes[1] = k[i].lanes[2] = k[i].lanes[3] = ek[i];
}
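
/*
 * Key expansion is done once per key; the resulting aes_ctr_key_data_t
 * is read-only afterwards and can be shared by any number of
 * encrypt/decrypt operations (clib_aes_ctr_init () copies it into a
 * per-message context).
 */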

static_always_inline void
clib_aes128_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
		 const u8 *iv, u8 *dst)
{
  aes_ctr_ctx_t ctx;
  clib_aes_ctr_init (&ctx, kd, iv, AES_KEY_128);
  clib_aes_ctr_transform (&ctx, src, dst, n_bytes, AES_KEY_128);
}

static_always_inline void
clib_aes192_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
		 const u8 *iv, u8 *dst)
{
  aes_ctr_ctx_t ctx;
  clib_aes_ctr_init (&ctx, kd, iv, AES_KEY_192);
  clib_aes_ctr_transform (&ctx, src, dst, n_bytes, AES_KEY_192);
}

static_always_inline void
clib_aes256_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
		 const u8 *iv, u8 *dst)
{
  aes_ctr_ctx_t ctx;
  clib_aes_ctr_init (&ctx, kd, iv, AES_KEY_256);
  clib_aes_ctr_transform (&ctx, src, dst, n_bytes, AES_KEY_256);
}
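
/*
 * One-shot usage sketch (key, iv and buffer names are hypothetical):
 *
 *   aes_ctr_key_data_t kd;
 *   clib_aes_ctr_key_expand (&kd, key, AES_KEY_128);
 *   clib_aes128_ctr (&kd, plaintext, n_bytes, iv, ciphertext);
 *
 * CTR mode is its own inverse, so the same call with ciphertext as the
 * source recovers the plaintext.
 */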

#endif /* __crypto_aes_ctr_h__ */