/* SPDX-License-Identifier: Apache-2.0
 * Copyright(c) 2023 Cisco Systems, Inc.
 */

#ifndef __crypto_aes_cbc_h__
#define __crypto_aes_cbc_h__

#include <vppinfra/clib.h>
#include <vppinfra/vector.h>
#include <vppinfra/crypto/aes.h>

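/*
 * Expanded AES round keys for CBC mode. The arrays are sized for the
 * largest case (AES-256: 14 rounds + 1 initial key); smaller key sizes
 * leave the tail entries unused. Fill this structure with
 * clib_aes_cbc_key_expand() (or one of the per-key-size wrappers below)
 * before calling any of the encrypt/decrypt helpers.
 */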
typedef struct
{
  const u8x16 encrypt_key[15];
  const u8x16 decrypt_key[15];
} aes_cbc_key_data_t;

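/*
 * Generic CBC encryption: each 16-byte plaintext block is XORed with the
 * previous ciphertext block (the IV for the first block) and then run
 * through the AES rounds. The chaining makes encryption inherently
 * serial, so only one block is in flight at a time. 'len' is expected to
 * be a multiple of 16; no padding is performed here.
 */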
static_always_inline void
clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
		      const u8 *iv, aes_key_size_t ks, u8 *dst)
{
  int rounds = AES_KEY_ROUNDS (ks);
  u8x16 r, *k = (u8x16 *) kd->encrypt_key;

  r = *(u8x16u *) iv;

  for (int i = 0; i < len; i += 16)
    {
      int j;
      r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
      for (j = 1; j < rounds; j++)
	r = aes_enc_round_x1 (r, k[j]);
      r = aes_enc_last_round_x1 (r, k[rounds]);
      *(u8x16u *) (dst + i) = r;
    }
}

static_always_inline void
clib_aes128_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
			 uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_128, ciphertext);
}

static_always_inline void
clib_aes192_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
			 uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_192, ciphertext);
}

static_always_inline void
clib_aes256_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
			 uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_256, ciphertext);
}

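/*
 * Portable CBC decryption for x86 (AES-NI) and ARM (NEON crypto
 * extensions). Unlike encryption, CBC decryption has no serial dependency
 * between blocks, so four blocks are processed per iteration to keep the
 * AES units busy; a single-block loop handles the tail.
 */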
static_always_inline void __clib_unused
aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
	     int rounds)
{
  u8x16 r[4], c[4], f;

  f = iv[0];
  while (count >= 64)
    {
      c[0] = r[0] = src[0];
      c[1] = r[1] = src[1];
      c[2] = r[2] = src[2];
      c[3] = r[3] = src[3];

#if __x86_64__
      r[0] ^= k[0];
      r[1] ^= k[0];
      r[2] ^= k[0];
      r[3] ^= k[0];

      for (int i = 1; i < rounds; i++)
	{
	  r[0] = aes_dec_round_x1 (r[0], k[i]);
	  r[1] = aes_dec_round_x1 (r[1], k[i]);
	  r[2] = aes_dec_round_x1 (r[2], k[i]);
	  r[3] = aes_dec_round_x1 (r[3], k[i]);
	}

      r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
      r[1] = aes_dec_last_round_x1 (r[1], k[rounds]);
      r[2] = aes_dec_last_round_x1 (r[2], k[rounds]);
      r[3] = aes_dec_last_round_x1 (r[3], k[rounds]);
#else
      for (int i = 0; i < rounds - 1; i++)
	{
	  r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
	  r[1] = vaesimcq_u8 (vaesdq_u8 (r[1], k[i]));
	  r[2] = vaesimcq_u8 (vaesdq_u8 (r[2], k[i]));
	  r[3] = vaesimcq_u8 (vaesdq_u8 (r[3], k[i]));
	}
      r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
      r[1] = vaesdq_u8 (r[1], k[rounds - 1]) ^ k[rounds];
      r[2] = vaesdq_u8 (r[2], k[rounds - 1]) ^ k[rounds];
      r[3] = vaesdq_u8 (r[3], k[rounds - 1]) ^ k[rounds];
#endif
      dst[0] = r[0] ^ f;
      dst[1] = r[1] ^ c[0];
      dst[2] = r[2] ^ c[1];
      dst[3] = r[3] ^ c[2];
      f = c[3];

      count -= 64;
      src += 4;
      dst += 4;
    }

  while (count > 0)
    {
      c[0] = r[0] = src[0];
#if __x86_64__
      r[0] ^= k[0];
      for (int i = 1; i < rounds; i++)
	r[0] = aes_dec_round_x1 (r[0], k[i]);
      r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
#else
      c[0] = r[0] = src[0];
      for (int i = 0; i < rounds - 1; i++)
	r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
      r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
#endif
      dst[0] = r[0] ^ f;
      f = c[0];

      count -= 16;
      src += 1;
      dst += 1;
    }
}

#if __x86_64__
#if defined(__VAES__) && defined(__AVX512F__)

static_always_inline u8x64
aes_block_load_x4 (u8 *src[], int i)
{
  u8x64 r = {};
  r = u8x64_insert_u8x16 (r, aes_block_load (src[0] + i), 0);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[1] + i), 1);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[2] + i), 2);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[3] + i), 3);
  return r;
}

static_always_inline void
aes_block_store_x4 (u8 *dst[], int i, u8x64 r)
{
  aes_block_store (dst[0] + i, u8x64_extract_u8x16 (r, 0));
  aes_block_store (dst[1] + i, u8x64_extract_u8x16 (r, 1));
  aes_block_store (dst[2] + i, u8x64_extract_u8x16 (r, 2));
  aes_block_store (dst[3] + i, u8x64_extract_u8x16 (r, 3));
}

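/*
 * Builds the "previous ciphertext" vector needed for the CBC XOR: the
 * shuffle concatenates the last 16-byte lane of 'a' with the first three
 * lanes of 'b', i.e. it shifts the ciphertext stream back by one block
 * across two 512-bit registers.
 */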
static_always_inline u8x64
aes4_cbc_dec_permute (u8x64 a, u8x64 b)
{
  return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
}

static_always_inline void
aes4_cbc_dec (const u8x16 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
	      aes_key_size_t rounds)
{
  u8x64 f, k4, r[4], c[4] = {};
  __mmask8 m;
  int i, n_blocks = count >> 4;

  f = u8x64_insert_u8x16 (u8x64_zero (), *iv, 3);

  while (n_blocks >= 16)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      c[3] = src[3];

      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;
      r[2] = c[2] ^ k4;
      r[3] = c[3] ^ k4;

      for (i = 1; i < rounds; i++)
	{
	  k4 = u8x64_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x4 (r[0], k4);
	  r[1] = aes_dec_round_x4 (r[1], k4);
	  r[2] = aes_dec_round_x4 (r[2], k4);
	  r[3] = aes_dec_round_x4 (r[3], k4);
	}

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);
      r[3] = aes_dec_last_round_x4 (r[3], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
      f = c[3];

      n_blocks -= 16;
      src += 4;
      dst += 4;
    }

  if (n_blocks >= 12)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];

      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;
      r[2] = c[2] ^ k4;

      for (i = 1; i < rounds; i++)
	{
	  k4 = u8x64_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x4 (r[0], k4);
	  r[1] = aes_dec_round_x4 (r[1], k4);
	  r[2] = aes_dec_round_x4 (r[2], k4);
	}

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
      f = c[2];

      n_blocks -= 12;
      src += 3;
      dst += 3;
    }
  else if (n_blocks >= 8)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];

      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;

      for (i = 1; i < rounds; i++)
	{
	  k4 = u8x64_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x4 (r[0], k4);
	  r[1] = aes_dec_round_x4 (r[1], k4);
	}

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      f = c[1];

      n_blocks -= 8;
      src += 2;
      dst += 2;
    }
  else if (n_blocks >= 4)
    {
      c[0] = src[0];

      r[0] = c[0] ^ u8x64_splat_u8x16 (k[0]);

      for (i = 1; i < rounds; i++)
	r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));

      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      f = c[0];

      n_blocks -= 4;
      src += 1;
      dst += 1;
    }

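  /*
   * Remaining 1-3 blocks: load/store with a 64-bit-lane mask. Each
   * 16-byte block occupies two 64-bit lanes, hence the mask covers
   * n_blocks * 2 lanes.
   */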
  if (n_blocks > 0)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      m = (1 << (n_blocks * 2)) - 1;
      c[0] =
	(u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m, (__m512i *) src);
      f = aes4_cbc_dec_permute (f, c[0]);
      r[0] = c[0] ^ k4;
      for (i = 1; i < rounds; i++)
	r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
    }
}
#elif defined(__VAES__)

static_always_inline u8x32
aes_block_load_x2 (u8 *src[], int i)
{
  u8x32 r = {};
  r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
  r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
  return r;
}

static_always_inline void
aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
{
  aes_block_store (dst[0] + i, u8x32_extract_lo (r));
  aes_block_store (dst[1] + i, u8x32_extract_hi (r));
}

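/*
 * 256-bit VAES variant of the decryption path above: same structure as
 * the AVX-512 code, but with two 16-byte blocks per vector register and
 * a scalar (single-block) fallback for a trailing odd block.
 */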
static_always_inline u8x32
aes2_cbc_dec_permute (u8x32 a, u8x32 b)
{
  return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
}

static_always_inline void
aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
	      aes_key_size_t rounds)
{
  u8x32 k2, f = {}, r[4], c[4] = {};
  int i, n_blocks = count >> 4;

  f = u8x32_insert_hi (f, *iv);

  while (n_blocks >= 8)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      c[3] = src[3];

      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;
      r[2] = c[2] ^ k2;
      r[3] = c[3] ^ k2;

      for (i = 1; i < rounds; i++)
	{
	  k2 = u8x32_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x2 (r[0], k2);
	  r[1] = aes_dec_round_x2 (r[1], k2);
	  r[2] = aes_dec_round_x2 (r[2], k2);
	  r[3] = aes_dec_round_x2 (r[3], k2);
	}

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);
      r[3] = aes_dec_last_round_x2 (r[3], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
      f = c[3];

      n_blocks -= 8;
      src += 4;
      dst += 4;
    }

  if (n_blocks >= 6)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];

      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;
      r[2] = c[2] ^ k2;

      for (i = 1; i < rounds; i++)
	{
	  k2 = u8x32_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x2 (r[0], k2);
	  r[1] = aes_dec_round_x2 (r[1], k2);
	  r[2] = aes_dec_round_x2 (r[2], k2);
	}

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
      f = c[2];

      n_blocks -= 6;
      src += 3;
      dst += 3;
    }
  else if (n_blocks >= 4)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];

      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;

      for (i = 1; i < rounds; i++)
	{
	  k2 = u8x32_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x2 (r[0], k2);
	  r[1] = aes_dec_round_x2 (r[1], k2);
	}

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      f = c[1];

      n_blocks -= 4;
      src += 2;
      dst += 2;
    }
  else if (n_blocks >= 2)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      r[0] = c[0] ^ k2;

      for (i = 1; i < rounds; i++)
	r[0] = aes_dec_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));

      r[0] = aes_dec_last_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      f = c[0];

      n_blocks -= 2;
      src += 1;
      dst += 1;
    }

  if (n_blocks > 0)
    {
      u8x16 rl = *(u8x16u *) src ^ k[0];
      for (i = 1; i < rounds; i++)
	rl = aes_dec_round_x1 (rl, k[i]);
      rl = aes_dec_last_round_x1 (rl, k[i]);
      *(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
    }
}
#endif
#endif

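/*
 * Expands the raw key into encryption round keys and derives the
 * decryption round keys from them (via aes_key_enc_to_dec), storing both
 * in the caller-provided key data structure.
 */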
static_always_inline void
clib_aes_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key,
			 aes_key_size_t ks)
{
  u8x16 e[15], d[15];
  aes_key_expand (e, key, ks);
  aes_key_enc_to_dec (e, d, ks);
  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
    {
      ((u8x16 *) kd->decrypt_key)[i] = d[i];
      ((u8x16 *) kd->encrypt_key)[i] = e[i];
    }
}

static_always_inline void
clib_aes128_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_128);
}
static_always_inline void
clib_aes192_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_192);
}
static_always_inline void
clib_aes256_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_256);
}

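/*
 * CBC decryption entry point: dispatches at compile time to the widest
 * implementation the target supports (512-bit VAES, 256-bit VAES, or the
 * portable 128-bit path).
 */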
static_always_inline void
clib_aes_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
		      uword len, const u8 *iv, aes_key_size_t ks,
		      u8 *plaintext)
{
  int rounds = AES_KEY_ROUNDS (ks);
#if defined(__VAES__) && defined(__AVX512F__)
  aes4_cbc_dec (kd->decrypt_key, (u8x64u *) ciphertext, (u8x64u *) plaintext,
		(u8x16u *) iv, (int) len, rounds);
#elif defined(__VAES__)
  aes2_cbc_dec (kd->decrypt_key, (u8x32u *) ciphertext, (u8x32u *) plaintext,
		(u8x16u *) iv, (int) len, rounds);
#else
  aes_cbc_dec (kd->decrypt_key, (u8x16u *) ciphertext, (u8x16u *) plaintext,
	       (u8x16u *) iv, (int) len, rounds);
#endif
}

static_always_inline void
clib_aes128_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
			 uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_128, plaintext);
}

static_always_inline void
clib_aes192_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
			 uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_192, plaintext);
}

static_always_inline void
clib_aes256_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
			 uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_256, plaintext);
}

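/*
 * Minimal usage sketch (illustrative only; buffer names and sizes are
 * hypothetical, and the data length must already be padded to a multiple
 * of the 16-byte AES block size):
 *
 *   aes_cbc_key_data_t kd;
 *   u8 key[32], iv[16];
 *   u8 plaintext[64], ciphertext[64], decrypted[64];
 *
 *   clib_aes256_cbc_key_expand (&kd, key);
 *   clib_aes256_cbc_encrypt (&kd, plaintext, 64, iv, ciphertext);
 *   clib_aes256_cbc_decrypt (&kd, ciphertext, 64, iv, decrypted);
 */
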
#endif /* __crypto_aes_cbc_h__ */