/* SPDX-License-Identifier: Apache-2.0
 * Copyright(c) 2023 Cisco Systems, Inc.
 */

#ifndef __crypto_aes_cbc_h__
#define __crypto_aes_cbc_h__

#include <vppinfra/clib.h>
#include <vppinfra/vector.h>
#include <vppinfra/crypto/aes.h>

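/* Expanded AES-CBC key schedules. Fifteen 128-bit round keys is the maximum
 * needed (AES-256 uses AES_KEY_ROUNDS (AES_KEY_256) + 1 = 15 keys); smaller
 * key sizes simply leave the tail entries unused. */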
typedef struct
{
  const u8x16 encrypt_key[15];
  const u8x16 decrypt_key[15];
} aes_cbc_key_data_t;

/* Generic single-buffer CBC encrypt. CBC encryption is inherently serial,
 * so blocks are chained one at a time; len is expected to be a multiple of
 * 16 bytes (no padding is applied here) and iv must point to 16 bytes. */
static_always_inline void
clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
                      const u8 *iv, aes_key_size_t ks, u8 *dst)
{
  int rounds = AES_KEY_ROUNDS (ks);
  u8x16 r, *k = (u8x16 *) kd->encrypt_key;

  r = *(u8x16u *) iv;

  for (int i = 0; i < len; i += 16)
    {
      int j;
#if __x86_64__
      r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
      for (j = 1; j < rounds; j++)
        r = aes_enc_round (r, k[j]);
      r = aes_enc_last_round (r, k[rounds]);
#else
      /* vaeseq_u8 adds the round key before SubBytes/ShiftRows, so the loop
         consumes k[0] .. k[rounds - 2] and the final step applies
         k[rounds - 1] plus the last whitening key k[rounds]. */
      r ^= *(u8x16u *) (src + i);
      for (j = 0; j < rounds - 1; j++)
        r = vaesmcq_u8 (vaeseq_u8 (r, k[j]));
      r = vaeseq_u8 (r, k[j]) ^ k[rounds];
#endif
      *(u8x16u *) (dst + i) = r;
    }
}

static_always_inline void
clib_aes128_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
                         uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_128, ciphertext);
}

static_always_inline void
clib_aes192_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
                         uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_192, ciphertext);
}

static_always_inline void
clib_aes256_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
                         uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_256, ciphertext);
}
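
/* Minimal usage sketch for the encrypt path (illustrative only; the key,
 * iv and buffer names below are hypothetical, lengths must be multiples of
 * 16 bytes, and clib_aes128_cbc_key_expand is defined further below):
 *
 *   aes_cbc_key_data_t kd;
 *   u8 key[16], iv[16], pt[64], ct[64];
 *
 *   clib_aes128_cbc_key_expand (&kd, key);
 *   clib_aes128_cbc_encrypt (&kd, pt, sizeof (pt), iv, ct);
 */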

/* Generic CBC decrypt used when no wider VAES variant is available. Unlike
 * encryption, CBC decryption parallelizes across blocks, so the main loop
 * decrypts four blocks at a time; count is in bytes. */
static_always_inline void __clib_unused
aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
             int rounds)
{
  u8x16 r[4], c[4], f;

  f = iv[0];
  while (count >= 64)
    {
      /* keep the ciphertext blocks (c[]) - they are the chaining values
         for the following blocks */
      c[0] = r[0] = src[0];
      c[1] = r[1] = src[1];
      c[2] = r[2] = src[2];
      c[3] = r[3] = src[3];

#if __x86_64__
      r[0] ^= k[0];
      r[1] ^= k[0];
      r[2] ^= k[0];
      r[3] ^= k[0];

      for (int i = 1; i < rounds; i++)
        {
          r[0] = aes_dec_round (r[0], k[i]);
          r[1] = aes_dec_round (r[1], k[i]);
          r[2] = aes_dec_round (r[2], k[i]);
          r[3] = aes_dec_round (r[3], k[i]);
        }

      r[0] = aes_dec_last_round (r[0], k[rounds]);
      r[1] = aes_dec_last_round (r[1], k[rounds]);
      r[2] = aes_dec_last_round (r[2], k[rounds]);
      r[3] = aes_dec_last_round (r[3], k[rounds]);
#else
      for (int i = 0; i < rounds - 1; i++)
        {
          r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
          r[1] = vaesimcq_u8 (vaesdq_u8 (r[1], k[i]));
          r[2] = vaesimcq_u8 (vaesdq_u8 (r[2], k[i]));
          r[3] = vaesimcq_u8 (vaesdq_u8 (r[3], k[i]));
        }
      r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
      r[1] = vaesdq_u8 (r[1], k[rounds - 1]) ^ k[rounds];
      r[2] = vaesdq_u8 (r[2], k[rounds - 1]) ^ k[rounds];
      r[3] = vaesdq_u8 (r[3], k[rounds - 1]) ^ k[rounds];
#endif
      dst[0] = r[0] ^ f;
      dst[1] = r[1] ^ c[0];
      dst[2] = r[2] ^ c[1];
      dst[3] = r[3] ^ c[2];
      f = c[3];

      count -= 64;
      src += 4;
      dst += 4;
    }

  /* remaining blocks, one at a time */
  while (count > 0)
    {
      c[0] = r[0] = src[0];
#if __x86_64__
      r[0] ^= k[0];
      for (int i = 1; i < rounds; i++)
        r[0] = aes_dec_round (r[0], k[i]);
      r[0] = aes_dec_last_round (r[0], k[rounds]);
#else
      for (int i = 0; i < rounds - 1; i++)
        r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
      r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
#endif
      dst[0] = r[0] ^ f;
      f = c[0];

      count -= 16;
      src += 1;
      dst += 1;
    }
}

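/* On x86_64, the variants below use VAES (and, where available, AVX512) to
 * process two or four AES blocks per 256/512-bit vector register. */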
#if __x86_64__
#if defined(__VAES__) && defined(__AVX512F__)

static_always_inline u8x64
aes_block_load_x4 (u8 *src[], int i)
{
  u8x64 r = {};
  r = u8x64_insert_u8x16 (r, aes_block_load (src[0] + i), 0);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[1] + i), 1);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[2] + i), 2);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[3] + i), 3);
  return r;
}

static_always_inline void
aes_block_store_x4 (u8 *dst[], int i, u8x64 r)
{
  aes_block_store (dst[0] + i, u8x64_extract_u8x16 (r, 0));
  aes_block_store (dst[1] + i, u8x64_extract_u8x16 (r, 1));
  aes_block_store (dst[2] + i, u8x64_extract_u8x16 (r, 2));
  aes_block_store (dst[3] + i, u8x64_extract_u8x16 (r, 3));
}
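
/* The x4 load/store helpers above gather one block from each of four
 * independent buffers into a single 512-bit register. They are not used by
 * the decrypt code in this header and are presumably intended for callers
 * that drive several AES-CBC streams in parallel. */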

static_always_inline u8x64
aes4_cbc_dec_permute (u8x64 a, u8x64 b)
{
  /* select the last block of a followed by the first three blocks of b,
     i.e. the ciphertext blocks preceding the four blocks held in b */
  return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
}

static_always_inline void
aes4_cbc_dec (const u8x16 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
              aes_key_size_t rounds)
{
  u8x64 f, k4, r[4], c[4] = {};
  __mmask8 m;
  int i, n_blocks = count >> 4;

  /* place the IV into the last 16-byte lane so aes4_cbc_dec_permute picks
     it up as the block preceding the first ciphertext block */
  f = u8x64_insert_u8x16 (u8x64_zero (), *iv, 3);

  /* 16 blocks (4 x 512-bit registers) per iteration */
  while (n_blocks >= 16)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      c[3] = src[3];

      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;
      r[2] = c[2] ^ k4;
      r[3] = c[3] ^ k4;

      for (i = 1; i < rounds; i++)
        {
          k4 = u8x64_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x4 (r[0], k4);
          r[1] = aes_dec_round_x4 (r[1], k4);
          r[2] = aes_dec_round_x4 (r[2], k4);
          r[3] = aes_dec_round_x4 (r[3], k4);
        }

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);
      r[3] = aes_dec_last_round_x4 (r[3], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
      f = c[3];

      n_blocks -= 16;
      src += 4;
      dst += 4;
    }

  /* 12, 8 or 4 remaining blocks, still using full 512-bit registers */
  if (n_blocks >= 12)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];

      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;
      r[2] = c[2] ^ k4;

      for (i = 1; i < rounds; i++)
        {
          k4 = u8x64_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x4 (r[0], k4);
          r[1] = aes_dec_round_x4 (r[1], k4);
          r[2] = aes_dec_round_x4 (r[2], k4);
        }

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
      f = c[2];

      n_blocks -= 12;
      src += 3;
      dst += 3;
    }
  else if (n_blocks >= 8)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];

      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;

      for (i = 1; i < rounds; i++)
        {
          k4 = u8x64_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x4 (r[0], k4);
          r[1] = aes_dec_round_x4 (r[1], k4);
        }

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      f = c[1];

      n_blocks -= 8;
      src += 2;
      dst += 2;
    }
  else if (n_blocks >= 4)
    {
      c[0] = src[0];

      r[0] = c[0] ^ u8x64_splat_u8x16 (k[0]);

      for (i = 1; i < rounds; i++)
        r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));

      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      f = c[0];

      n_blocks -= 4;
      src += 1;
      dst += 1;
    }

  if (n_blocks > 0)
    {
      /* partially filled 512-bit register: the mask covers two 64-bit
         lanes per remaining block */
      k4 = u8x64_splat_u8x16 (k[0]);
      m = (1 << (n_blocks * 2)) - 1;
      c[0] =
        (u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m, (__m512i *) src);
      f = aes4_cbc_dec_permute (f, c[0]);
      r[0] = c[0] ^ k4;
      for (i = 1; i < rounds; i++)
        r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
    }
}
#elif defined(__VAES__)

static_always_inline u8x32
aes_block_load_x2 (u8 *src[], int i)
{
  u8x32 r = {};
  r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
  r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
  return r;
}

static_always_inline void
aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
{
  aes_block_store (dst[0] + i, u8x32_extract_lo (r));
  aes_block_store (dst[1] + i, u8x32_extract_hi (r));
}

static_always_inline u8x32
aes2_cbc_dec_permute (u8x32 a, u8x32 b)
{
  /* select the last block of a followed by the first block of b */
  return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
}

static_always_inline void
aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
              aes_key_size_t rounds)
{
  u8x32 k2, f = {}, r[4], c[4] = {};
  int i, n_blocks = count >> 4;

  /* IV goes into the high lane; aes2_cbc_dec_permute then supplies it as
     the block preceding the first ciphertext block */
  f = u8x32_insert_hi (f, *iv);

  /* 8 blocks (4 x 256-bit registers) per iteration */
  while (n_blocks >= 8)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      c[3] = src[3];

      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;
      r[2] = c[2] ^ k2;
      r[3] = c[3] ^ k2;

      for (i = 1; i < rounds; i++)
        {
          k2 = u8x32_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x2 (r[0], k2);
          r[1] = aes_dec_round_x2 (r[1], k2);
          r[2] = aes_dec_round_x2 (r[2], k2);
          r[3] = aes_dec_round_x2 (r[3], k2);
        }

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);
      r[3] = aes_dec_last_round_x2 (r[3], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
      f = c[3];

      n_blocks -= 8;
      src += 4;
      dst += 4;
    }

  /* 6, 4 or 2 remaining blocks, still using full 256-bit registers */
  if (n_blocks >= 6)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];

      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;
      r[2] = c[2] ^ k2;

      for (i = 1; i < rounds; i++)
        {
          k2 = u8x32_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x2 (r[0], k2);
          r[1] = aes_dec_round_x2 (r[1], k2);
          r[2] = aes_dec_round_x2 (r[2], k2);
        }

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
      f = c[2];

      n_blocks -= 6;
      src += 3;
      dst += 3;
    }
  else if (n_blocks >= 4)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];

      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;

      for (i = 1; i < rounds; i++)
        {
          k2 = u8x32_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x2 (r[0], k2);
          r[1] = aes_dec_round_x2 (r[1], k2);
        }

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      f = c[1];

      n_blocks -= 4;
      src += 2;
      dst += 2;
    }
  else if (n_blocks >= 2)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      r[0] = c[0] ^ k2;

      for (i = 1; i < rounds; i++)
        r[0] = aes_dec_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));

      r[0] = aes_dec_last_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      f = c[0];

      n_blocks -= 2;
      src += 1;
      dst += 1;
    }

  if (n_blocks > 0)
    {
      /* final odd block, decrypted with plain 128-bit operations; the last
         ciphertext block (or the IV) lives in the high lane of f */
      u8x16 rl = *(u8x16u *) src ^ k[0];
      for (i = 1; i < rounds; i++)
        rl = aes_dec_round (rl, k[i]);
      rl = aes_dec_last_round (rl, k[i]);
      *(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
    }
}
#endif
#endif

static_always_inline void
clib_aes_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key,
                         aes_key_size_t ks)
{
  u8x16 e[15], d[15];
  /* expand the encryption schedule, then derive the decryption schedule
     (equivalent inverse cipher) from it */
  aes_key_expand (e, key, ks);
  aes_key_enc_to_dec (e, d, ks);
  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
    {
      ((u8x16 *) kd->decrypt_key)[i] = d[i];
      ((u8x16 *) kd->encrypt_key)[i] = e[i];
    }
}

static_always_inline void
clib_aes128_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_128);
}

static_always_inline void
clib_aes192_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_192);
}

static_always_inline void
clib_aes256_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_256);
}

/* Decrypt dispatch: picks the widest implementation the build supports
 * (4 blocks per register with VAES + AVX512, 2 with VAES, otherwise the
 * generic 128-bit path). len is in bytes. */
static_always_inline void
clib_aes_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                      uword len, const u8 *iv, aes_key_size_t ks,
                      u8 *plaintext)
{
  int rounds = AES_KEY_ROUNDS (ks);
#if defined(__VAES__) && defined(__AVX512F__)
  aes4_cbc_dec (kd->decrypt_key, (u8x64u *) ciphertext, (u8x64u *) plaintext,
                (u8x16u *) iv, (int) len, rounds);
#elif defined(__VAES__)
  aes2_cbc_dec (kd->decrypt_key, (u8x32u *) ciphertext, (u8x32u *) plaintext,
                (u8x16u *) iv, (int) len, rounds);
#else
  aes_cbc_dec (kd->decrypt_key, (u8x16u *) ciphertext, (u8x16u *) plaintext,
               (u8x16u *) iv, (int) len, rounds);
#endif
}

static_always_inline void
clib_aes128_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                         uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_128, plaintext);
}

static_always_inline void
clib_aes192_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                         uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_192, plaintext);
}

static_always_inline void
clib_aes256_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                         uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_256, plaintext);
}
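
/* Minimal decrypt sketch matching the encrypt example above (names remain
 * hypothetical; the same expanded key data is reused):
 *
 *   u8 out[64];
 *   clib_aes128_cbc_decrypt (&kd, ct, sizeof (ct), iv, out);
 */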

#endif /* __crypto_aes_cbc_h__ */