wireguard: add processing of received cookie messages

Type: feature

Currently, if a handshake message is sent and a cookie message is
received in reply, the cookie message will be ignored. Thus, further
handshake messages will not have valid mac2 and handshake will not be
able to be completed.

With this change, process received cookie messages to be able to
calculate mac2 for further handshake messages sent. Cover this with
tests.

Signed-off-by: Alexander Chernavin <achernavin@netgate.com>
Change-Id: I6d51459778b7145be7077badec479b2aa85960b9
diff --git a/src/plugins/wireguard/CMakeLists.txt b/src/plugins/wireguard/CMakeLists.txt
index 6dddc67..31f09f1 100644
--- a/src/plugins/wireguard/CMakeLists.txt
+++ b/src/plugins/wireguard/CMakeLists.txt
@@ -33,8 +33,11 @@
   wireguard_input.c
   wireguard_output_tun.c
   wireguard_handoff.c
+  wireguard_hchacha20.h
   wireguard_key.c
   wireguard_key.h
+  wireguard_chachapoly.c
+  wireguard_chachapoly.h
   wireguard_cli.c
   wireguard_messages.h
   wireguard_noise.c
diff --git a/src/plugins/wireguard/wireguard.c b/src/plugins/wireguard/wireguard.c
index 926da2c..5d73638 100644
--- a/src/plugins/wireguard/wireguard.c
+++ b/src/plugins/wireguard/wireguard.c
@@ -59,6 +59,13 @@
     vnet_crypto_register_post_node (vm, "wg6-input-post-node");
 }
 
+void
+wg_secure_zero_memory (void *v, size_t n)
+{
+  static void *(*const volatile memset_v) (void *, int, size_t) = &memset;
+  memset_v (v, 0, n);
+}
+
 static clib_error_t *
 wg_init (vlib_main_t * vm)
 {
diff --git a/src/plugins/wireguard/wireguard.h b/src/plugins/wireguard/wireguard.h
index ba96864..3a6248b 100644
--- a/src/plugins/wireguard/wireguard.h
+++ b/src/plugins/wireguard/wireguard.h
@@ -117,6 +117,8 @@
 void wg_feature_init (wg_main_t * wmp);
 void wg_set_async_mode (u32 is_enabled);
 
+void wg_secure_zero_memory (void *v, size_t n);
+
 #endif /* __included_wg_h__ */
 
 /*
diff --git a/src/plugins/wireguard/wireguard_chachapoly.c b/src/plugins/wireguard/wireguard_chachapoly.c
new file mode 100644
index 0000000..961b43f
--- /dev/null
+++ b/src/plugins/wireguard/wireguard_chachapoly.c
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2022 Rubicon Communications, LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at:
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <wireguard/wireguard.h>
+#include <wireguard/wireguard_chachapoly.h>
+#include <wireguard/wireguard_hchacha20.h>
+
+bool
+wg_chacha20poly1305_calc (vlib_main_t *vm, u8 *src, u32 src_len, u8 *dst,
+			  u8 *aad, u32 aad_len, u64 nonce,
+			  vnet_crypto_op_id_t op_id,
+			  vnet_crypto_key_index_t key_index)
+{
+  vnet_crypto_op_t _op, *op = &_op;
+  u8 iv[12];
+  u8 tag_[NOISE_AUTHTAG_LEN] = {};
+  u8 src_[] = {};
+
+  clib_memset (iv, 0, 12);
+  clib_memcpy (iv + 4, &nonce, sizeof (nonce));
+
+  vnet_crypto_op_init (op, op_id);
+
+  op->tag_len = NOISE_AUTHTAG_LEN;
+  if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC)
+    {
+      op->tag = src + src_len - NOISE_AUTHTAG_LEN;
+      src_len -= NOISE_AUTHTAG_LEN;
+      op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
+    }
+  else
+    op->tag = tag_;
+
+  op->src = !src ? src_ : src;
+  op->len = src_len;
+
+  op->dst = dst;
+  op->key_index = key_index;
+  op->aad = aad;
+  op->aad_len = aad_len;
+  op->iv = iv;
+
+  vnet_crypto_process_ops (vm, op, 1);
+  if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC)
+    {
+      clib_memcpy (dst + src_len, op->tag, NOISE_AUTHTAG_LEN);
+    }
+
+  return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED);
+}
+
+bool
+wg_xchacha20poly1305_decrypt (vlib_main_t *vm, u8 *src, u32 src_len, u8 *dst,
+			      u8 *aad, u32 aad_len,
+			      u8 nonce[XCHACHA20POLY1305_NONCE_SIZE],
+			      u8 key[CHACHA20POLY1305_KEY_SIZE])
+{
+  int ret, i;
+  u32 derived_key[CHACHA20POLY1305_KEY_SIZE / sizeof (u32)];
+  u64 h_nonce;
+
+  clib_memcpy (&h_nonce, nonce + 16, sizeof (h_nonce));
+  h_nonce = le64toh (h_nonce);
+  hchacha20 (derived_key, nonce, key);
+
+  for (i = 0; i < (sizeof (derived_key) / sizeof (derived_key[0])); i++)
+    (derived_key[i]) = htole32 ((derived_key[i]));
+
+  uint32_t key_idx;
+
+  key_idx =
+    vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305,
+			 (uint8_t *) derived_key, CHACHA20POLY1305_KEY_SIZE);
+
+  ret =
+    wg_chacha20poly1305_calc (vm, src, src_len, dst, aad, aad_len, h_nonce,
+			      VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx);
+
+  vnet_crypto_key_del (vm, key_idx);
+  wg_secure_zero_memory (derived_key, CHACHA20POLY1305_KEY_SIZE);
+
+  return ret;
+}
+
+/*
+ * fd.io coding-style-patch-verification: ON
+ *
+ * Local Variables:
+ * eval: (c-set-style "gnu")
+ * End:
+ */
diff --git a/src/plugins/wireguard/wireguard_chachapoly.h b/src/plugins/wireguard/wireguard_chachapoly.h
new file mode 100644
index 0000000..803774c
--- /dev/null
+++ b/src/plugins/wireguard/wireguard_chachapoly.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022 Rubicon Communications, LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at:
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __included_wg_chachapoly_h__
+#define __included_wg_chachapoly_h__
+
+#include <vlib/vlib.h>
+#include <vnet/crypto/crypto.h>
+
+#define XCHACHA20POLY1305_NONCE_SIZE 24
+#define CHACHA20POLY1305_KEY_SIZE    32
+
+bool wg_chacha20poly1305_calc (vlib_main_t *vm, u8 *src, u32 src_len, u8 *dst,
+			       u8 *aad, u32 aad_len, u64 nonce,
+			       vnet_crypto_op_id_t op_id,
+			       vnet_crypto_key_index_t key_index);
+
+bool wg_xchacha20poly1305_decrypt (vlib_main_t *vm, u8 *src, u32 src_len,
+				   u8 *dst, u8 *aad, u32 aad_len,
+				   u8 nonce[XCHACHA20POLY1305_NONCE_SIZE],
+				   u8 key[CHACHA20POLY1305_KEY_SIZE]);
+
+#endif /* __included_wg_chachapoly_h__ */
+
+/*
+ * fd.io coding-style-patch-verification: ON
+ *
+ * Local Variables:
+ * eval: (c-set-style "gnu")
+ * End:
+ */
diff --git a/src/plugins/wireguard/wireguard_cookie.c b/src/plugins/wireguard/wireguard_cookie.c
index c4279b7..47e8784 100644
--- a/src/plugins/wireguard/wireguard_cookie.c
+++ b/src/plugins/wireguard/wireguard_cookie.c
@@ -20,6 +20,7 @@
 #include <vlib/vlib.h>
 
 #include <wireguard/wireguard_cookie.h>
+#include <wireguard/wireguard_chachapoly.h>
 #include <wireguard/wireguard.h>
 
 static void cookie_precompute_key (uint8_t *,
@@ -57,6 +58,32 @@
     }
 }
 
+bool
+cookie_maker_consume_payload (vlib_main_t *vm, cookie_maker_t *cp,
+			      uint8_t nonce[COOKIE_NONCE_SIZE],
+			      uint8_t ecookie[COOKIE_ENCRYPTED_SIZE])
+{
+  uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+  if (cp->cp_mac1_valid == 0)
+    {
+      return false;
+    }
+
+  if (!wg_xchacha20poly1305_decrypt (vm, ecookie, COOKIE_ENCRYPTED_SIZE,
+				     cookie, cp->cp_mac1_last, COOKIE_MAC_SIZE,
+				     nonce, cp->cp_cookie_key))
+    {
+      return false;
+    }
+
+  clib_memcpy (cp->cp_cookie, cookie, COOKIE_COOKIE_SIZE);
+  cp->cp_birthdate = vlib_time_now (vm);
+  cp->cp_mac1_valid = 0;
+
+  return true;
+}
+
 void
 cookie_maker_mac (cookie_maker_t * cp, message_macs_t * cm, void *buf,
 		  size_t len)
diff --git a/src/plugins/wireguard/wireguard_cookie.h b/src/plugins/wireguard/wireguard_cookie.h
index 6ef418f..e4bea90 100644
--- a/src/plugins/wireguard/wireguard_cookie.h
+++ b/src/plugins/wireguard/wireguard_cookie.h
@@ -82,6 +82,9 @@
 
 void cookie_maker_init (cookie_maker_t *, const uint8_t[COOKIE_INPUT_SIZE]);
 void cookie_checker_update (cookie_checker_t *, uint8_t[COOKIE_INPUT_SIZE]);
+bool cookie_maker_consume_payload (vlib_main_t *vm, cookie_maker_t *cp,
+				   uint8_t nonce[COOKIE_NONCE_SIZE],
+				   uint8_t ecookie[COOKIE_ENCRYPTED_SIZE]);
 void cookie_maker_mac (cookie_maker_t *, message_macs_t *, void *, size_t);
 enum cookie_mac_state
 cookie_checker_validate_macs (vlib_main_t *vm, cookie_checker_t *,
diff --git a/src/plugins/wireguard/wireguard_hchacha20.h b/src/plugins/wireguard/wireguard_hchacha20.h
new file mode 100644
index 0000000..a2d1396
--- /dev/null
+++ b/src/plugins/wireguard/wireguard_hchacha20.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2022 Rubicon Communications, LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at:
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * chacha-merged.c version 20080118
+ * D. J. Bernstein
+ * Public domain.
+ */
+
+#ifndef __included_wg_hchacha20_h__
+#define __included_wg_hchacha20_h__
+
+#include <vlib/vlib.h>
+
+/* clang-format off */
+#define U32C(v) (v##U)
+#define U32V(v) ((u32)(v) & U32C(0xFFFFFFFF))
+
+#define ROTL32(v, n) \
+  (U32V((v) << (n)) | ((v) >> (32 - (n))))
+
+#define U8TO32_LITTLE(p) \
+  (((u32)((p)[0])      ) | \
+   ((u32)((p)[1]) <<  8) | \
+   ((u32)((p)[2]) << 16) | \
+   ((u32)((p)[3]) << 24))
+
+#define ROTATE(v,c) (ROTL32(v,c))
+#define XOR(v,w) ((v) ^ (w))
+#define PLUS(v,w) (U32V((v) + (w)))
+
+#define QUARTERROUND(a,b,c,d) \
+  a = PLUS(a,b); d = ROTATE(XOR(d,a),16); \
+  c = PLUS(c,d); b = ROTATE(XOR(b,c),12); \
+  a = PLUS(a,b); d = ROTATE(XOR(d,a), 8); \
+  c = PLUS(c,d); b = ROTATE(XOR(b,c), 7);
+/* clang-format on */
+
+static const char sigma[16] = "expand 32-byte k";
+
+static inline void
+hchacha20 (u32 derived_key[8], const u8 nonce[16], const u8 key[32])
+{
+  int i;
+  u32 x[] = { U8TO32_LITTLE (sigma + 0), U8TO32_LITTLE (sigma + 4),
+	      U8TO32_LITTLE (sigma + 8), U8TO32_LITTLE (sigma + 12),
+	      U8TO32_LITTLE (key + 0),	 U8TO32_LITTLE (key + 4),
+	      U8TO32_LITTLE (key + 8),	 U8TO32_LITTLE (key + 12),
+	      U8TO32_LITTLE (key + 16),	 U8TO32_LITTLE (key + 20),
+	      U8TO32_LITTLE (key + 24),	 U8TO32_LITTLE (key + 28),
+	      U8TO32_LITTLE (nonce + 0), U8TO32_LITTLE (nonce + 4),
+	      U8TO32_LITTLE (nonce + 8), U8TO32_LITTLE (nonce + 12) };
+
+  for (i = 20; i > 0; i -= 2)
+    {
+      QUARTERROUND (x[0], x[4], x[8], x[12])
+      QUARTERROUND (x[1], x[5], x[9], x[13])
+      QUARTERROUND (x[2], x[6], x[10], x[14])
+      QUARTERROUND (x[3], x[7], x[11], x[15])
+      QUARTERROUND (x[0], x[5], x[10], x[15])
+      QUARTERROUND (x[1], x[6], x[11], x[12])
+      QUARTERROUND (x[2], x[7], x[8], x[13])
+      QUARTERROUND (x[3], x[4], x[9], x[14])
+    }
+
+  clib_memcpy (derived_key + 0, x + 0, sizeof (u32) * 4);
+  clib_memcpy (derived_key + 4, x + 12, sizeof (u32) * 4);
+}
+
+#endif /* __included_wg_hchacha20_h__ */
+
+/*
+ * fd.io coding-style-patch-verification: ON
+ *
+ * Local Variables:
+ * eval: (c-set-style "gnu")
+ * End:
+ */
diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c
index 3eba9cb..ef60d50 100644
--- a/src/plugins/wireguard/wireguard_input.c
+++ b/src/plugins/wireguard/wireguard_input.c
@@ -31,6 +31,7 @@
   _ (KEEPALIVE_SEND, "Failed while sending Keepalive")                        \
   _ (HANDSHAKE_SEND, "Failed while sending Handshake")                        \
   _ (HANDSHAKE_RECEIVE, "Failed while receiving Handshake")                   \
+  _ (COOKIE_DECRYPTION, "Failed during Cookie decryption")                    \
   _ (TOO_BIG, "Packet too big")                                               \
   _ (UNDEFINED, "Undefined error")                                            \
   _ (CRYPTO_ENGINE_ERROR, "crypto engine error (packet dropped)")
@@ -185,7 +186,9 @@
       else
 	return WG_INPUT_ERROR_PEER;
 
-      // TODO: Implement cookie_maker_consume_payload
+      if (!cookie_maker_consume_payload (
+	    vm, &peer->cookie_maker, packet->nonce, packet->encrypted_cookie))
+	return WG_INPUT_ERROR_COOKIE_DECRYPTION;
 
       return WG_INPUT_ERROR_NONE;
     }
diff --git a/src/plugins/wireguard/wireguard_noise.c b/src/plugins/wireguard/wireguard_noise.c
index 9c6e65c..c9d8e31 100644
--- a/src/plugins/wireguard/wireguard_noise.c
+++ b/src/plugins/wireguard/wireguard_noise.c
@@ -17,6 +17,7 @@
 
 #include <openssl/hmac.h>
 #include <wireguard/wireguard.h>
+#include <wireguard/wireguard_chachapoly.h>
 
 /* This implements Noise_IKpsk2:
  *
@@ -67,8 +68,6 @@
 
 static void noise_tai64n_now (uint8_t[NOISE_TIMESTAMP_LEN]);
 
-static void secure_zero_memory (void *v, size_t n);
-
 /* Set/Get noise parameters */
 void
 noise_local_init (noise_local_t * l, struct noise_upcall *upcall)
@@ -110,7 +109,7 @@
     clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
 
   noise_remote_handshake_index_drop (r);
-  secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+  wg_secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
 }
 
 /* Handshake functions */
@@ -161,7 +160,7 @@
   *s_idx = hs->hs_local_index;
   ret = true;
 error:
-  secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  wg_secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
   vnet_crypto_key_del (vm, key_idx);
   return ret;
 }
@@ -244,9 +243,9 @@
   ret = true;
 
 error:
-  secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  wg_secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
   vnet_crypto_key_del (vm, key_idx);
-  secure_zero_memory (&hs, sizeof (hs));
+  wg_secure_zero_memory (&hs, sizeof (hs));
   return ret;
 }
 
@@ -297,9 +296,9 @@
   *s_idx = hs->hs_local_index;
   ret = true;
 error:
-  secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  wg_secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
   vnet_crypto_key_del (vm, key_idx);
-  secure_zero_memory (e, NOISE_PUBLIC_KEY_LEN);
+  wg_secure_zero_memory (e, NOISE_PUBLIC_KEY_LEN);
   return ret;
 }
 
@@ -358,8 +357,8 @@
       ret = true;
     }
 error:
-  secure_zero_memory (&hs, sizeof (hs));
-  secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  wg_secure_zero_memory (&hs, sizeof (hs));
+  wg_secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
   vnet_crypto_key_del (vm, key_idx);
   return ret;
 }
@@ -443,9 +442,9 @@
   vlib_worker_thread_barrier_release (vm);
   clib_rwlock_writer_unlock (&r->r_keypair_lock);
 
-  secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+  wg_secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
 
-  secure_zero_memory (&kp, sizeof (kp));
+  wg_secure_zero_memory (&kp, sizeof (kp));
   return true;
 }
 
@@ -453,7 +452,7 @@
 noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
 {
   noise_remote_handshake_index_drop (r);
-  secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+  wg_secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
 
   clib_rwlock_writer_lock (&r->r_keypair_lock);
   noise_remote_keypair_free (vm, r, &r->r_next);
@@ -495,55 +494,6 @@
   return ret;
 }
 
-static bool
-chacha20poly1305_calc (vlib_main_t * vm,
-		       u8 * src,
-		       u32 src_len,
-		       u8 * dst,
-		       u8 * aad,
-		       u32 aad_len,
-		       u64 nonce,
-		       vnet_crypto_op_id_t op_id,
-		       vnet_crypto_key_index_t key_index)
-{
-  vnet_crypto_op_t _op, *op = &_op;
-  u8 iv[12];
-  u8 tag_[NOISE_AUTHTAG_LEN] = { };
-  u8 src_[] = { };
-
-  clib_memset (iv, 0, 12);
-  clib_memcpy (iv + 4, &nonce, sizeof (nonce));
-
-  vnet_crypto_op_init (op, op_id);
-
-  op->tag_len = NOISE_AUTHTAG_LEN;
-  if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC)
-    {
-      op->tag = src + src_len - NOISE_AUTHTAG_LEN;
-      src_len -= NOISE_AUTHTAG_LEN;
-      op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
-    }
-  else
-    op->tag = tag_;
-
-  op->src = !src ? src_ : src;
-  op->len = src_len;
-
-  op->dst = dst;
-  op->key_index = key_index;
-  op->aad = aad;
-  op->aad_len = aad_len;
-  op->iv = iv;
-
-  vnet_crypto_process_ops (vm, op, 1);
-  if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC)
-    {
-      clib_memcpy (dst + src_len, op->tag, NOISE_AUTHTAG_LEN);
-    }
-
-  return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED);
-}
-
 enum noise_state_crypt
 noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
 		      uint64_t * nonce, uint8_t * src, size_t srclen,
@@ -572,9 +522,9 @@
    * are passed back out to the caller through the provided data pointer. */
   *r_idx = kp->kp_remote_index;
 
-  chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, *nonce,
-			 VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC,
-			 kp->kp_send_index);
+  wg_chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, *nonce,
+			    VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC,
+			    kp->kp_send_index);
 
   /* If our values are still within tolerances, but we are approaching
    * the tolerances, we notify the caller with ESTALE that they should
@@ -666,8 +616,8 @@
 
 out:
   /* Clear sensitive data from stack */
-  secure_zero_memory (sec, BLAKE2S_HASH_SIZE);
-  secure_zero_memory (out, BLAKE2S_HASH_SIZE + 1);
+  wg_secure_zero_memory (sec, BLAKE2S_HASH_SIZE);
+  wg_secure_zero_memory (out, BLAKE2S_HASH_SIZE + 1);
 }
 
 static bool
@@ -682,7 +632,7 @@
   noise_kdf (ck, key, NULL, dh,
 	     NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
 	     ck);
-  secure_zero_memory (dh, NOISE_PUBLIC_KEY_LEN);
+  wg_secure_zero_memory (dh, NOISE_PUBLIC_KEY_LEN);
   return true;
 }
 
@@ -723,7 +673,7 @@
 	     NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
 	     NOISE_SYMMETRIC_KEY_LEN, ck);
   noise_mix_hash (hash, tmp, NOISE_HASH_LEN);
-  secure_zero_memory (tmp, NOISE_HASH_LEN);
+  wg_secure_zero_memory (tmp, NOISE_HASH_LEN);
 }
 
 static void
@@ -750,8 +700,8 @@
 		   uint8_t hash[NOISE_HASH_LEN])
 {
   /* Nonce always zero for Noise_IK */
-  chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
-			 VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, key_idx);
+  wg_chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
+			    VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, key_idx);
   noise_mix_hash (hash, dst, src_len + NOISE_AUTHTAG_LEN);
 }
 
@@ -761,8 +711,9 @@
 		   uint8_t hash[NOISE_HASH_LEN])
 {
   /* Nonce always zero for Noise_IK */
-  if (!chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
-			      VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx))
+  if (!wg_chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN,
+				 0, VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC,
+				 key_idx))
     return false;
   noise_mix_hash (hash, src, src_len);
   return true;
@@ -800,13 +751,6 @@
   clib_memcpy (output + sizeof (sec), &nsec, sizeof (nsec));
 }
 
-static void
-secure_zero_memory (void *v, size_t n)
-{
-  static void *(*const volatile memset_v) (void *, int, size_t) = &memset;
-  memset_v (v, 0, n);
-}
-
 /*
  * fd.io coding-style-patch-verification: ON
  *
diff --git a/src/plugins/wireguard/wireguard_timer.h b/src/plugins/wireguard/wireguard_timer.h
index 9d5c071..ebde47e 100644
--- a/src/plugins/wireguard/wireguard_timer.h
+++ b/src/plugins/wireguard/wireguard_timer.h
@@ -57,6 +57,8 @@
 static inline bool
 wg_birthdate_has_expired (f64 birthday_seconds, f64 expiration_seconds)
 {
+  if (birthday_seconds == 0.0)
+    return true;
   f64 now_seconds = vlib_time_now (vlib_get_main ());
   return (birthday_seconds + expiration_seconds) < now_seconds;
 }
diff --git a/test/requirements-3.txt b/test/requirements-3.txt
index 2e8f17d..64da933 100644
--- a/test/requirements-3.txt
+++ b/test/requirements-3.txt
@@ -310,6 +310,38 @@
     --hash=sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9 \
     --hash=sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206
     # via cffi
+pycryptodome==3.15.0 \
+    --hash=sha256:045d75527241d17e6ef13636d845a12e54660aa82e823b3b3341bcf5af03fa79 \
+    --hash=sha256:0926f7cc3735033061ef3cf27ed16faad6544b14666410727b31fea85a5b16eb \
+    --hash=sha256:092a26e78b73f2530b8bd6b3898e7453ab2f36e42fd85097d705d6aba2ec3e5e \
+    --hash=sha256:1b22bcd9ec55e9c74927f6b1f69843cb256fb5a465088ce62837f793d9ffea88 \
+    --hash=sha256:2aa55aae81f935a08d5a3c2042eb81741a43e044bd8a81ea7239448ad751f763 \
+    --hash=sha256:2ea63d46157386c5053cfebcdd9bd8e0c8b7b0ac4a0507a027f5174929403884 \
+    --hash=sha256:2ec709b0a58b539a4f9d33fb8508264c3678d7edb33a68b8906ba914f71e8c13 \
+    --hash=sha256:2ffd8b31561455453ca9f62cb4c24e6b8d119d6d531087af5f14b64bee2c23e6 \
+    --hash=sha256:4b52cb18b0ad46087caeb37a15e08040f3b4c2d444d58371b6f5d786d95534c2 \
+    --hash=sha256:4c3ccad74eeb7b001f3538643c4225eac398c77d617ebb3e57571a897943c667 \
+    --hash=sha256:5099c9ca345b2f252f0c28e96904643153bae9258647585e5e6f649bb7a1844a \
+    --hash=sha256:57f565acd2f0cf6fb3e1ba553d0cb1f33405ec1f9c5ded9b9a0a5320f2c0bd3d \
+    --hash=sha256:60b4faae330c3624cc5a546ba9cfd7b8273995a15de94ee4538130d74953ec2e \
+    --hash=sha256:7c9ed8aa31c146bef65d89a1b655f5f4eab5e1120f55fc297713c89c9e56ff0b \
+    --hash=sha256:7e3a8f6ee405b3bd1c4da371b93c31f7027944b2bcce0697022801db93120d83 \
+    --hash=sha256:9135dddad504592bcc18b0d2d95ce86c3a5ea87ec6447ef25cfedea12d6018b8 \
+    --hash=sha256:9c772c485b27967514d0df1458b56875f4b6d025566bf27399d0c239ff1b369f \
+    --hash=sha256:9eaadc058106344a566dc51d3d3a758ab07f8edde013712bc8d22032a86b264f \
+    --hash=sha256:9ee40e2168f1348ae476676a2e938ca80a2f57b14a249d8fe0d3cdf803e5a676 \
+    --hash=sha256:a8f06611e691c2ce45ca09bbf983e2ff2f8f4f87313609d80c125aff9fad6e7f \
+    --hash=sha256:b9c5b1a1977491533dfd31e01550ee36ae0249d78aae7f632590db833a5012b8 \
+    --hash=sha256:b9cc96e274b253e47ad33ae1fccc36ea386f5251a823ccb50593a935db47fdd2 \
+    --hash=sha256:c3640deff4197fa064295aaac10ab49a0d55ef3d6a54ae1499c40d646655c89f \
+    --hash=sha256:c77126899c4b9c9827ddf50565e93955cb3996813c18900c16b2ea0474e130e9 \
+    --hash=sha256:d2a39a66057ab191e5c27211a7daf8f0737f23acbf6b3562b25a62df65ffcb7b \
+    --hash=sha256:e244ab85c422260de91cda6379e8e986405b4f13dc97d2876497178707f87fc1 \
+    --hash=sha256:ecaaef2d21b365d9c5ca8427ffc10cebed9d9102749fd502218c23cb9a05feb5 \
+    --hash=sha256:fd2184aae6ee2a944aaa49113e6f5787cdc5e4db1eb8edb1aea914bd75f33a0c \
+    --hash=sha256:ff287bcba9fbeb4f1cccc1f2e90a08d691480735a611ee83c80a7d74ad72b9d9 \
+    --hash=sha256:ff7ae90e36c1715a54446e7872b76102baa5c63aa980917f4aa45e8c78d1a3ec
+    # via -r requirements.txt
 pyenchant==3.2.2 \
     --hash=sha256:1cf830c6614362a78aab78d50eaf7c6c93831369c52e1bb64ffae1df0341e637 \
     --hash=sha256:5a636832987eaf26efe971968f4d1b78e81f62bca2bde0a9da210c7de43c3bce \
diff --git a/test/requirements.txt b/test/requirements.txt
index a177967..509fe89 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -20,3 +20,4 @@
 jsonschema; python_version >= '3.7'     # MIT
 dataclasses; python_version == '3.6'    # Apache-2.0
 black                                   # MIT https://github.com/psf/black
+pycryptodome                            # BSD, Public Domain
diff --git a/test/test_wireguard.py b/test/test_wireguard.py
index 8ab0cbc..7395402 100644
--- a/test/test_wireguard.py
+++ b/test/test_wireguard.py
@@ -16,6 +16,7 @@
     WireguardResponse,
     WireguardInitiation,
     WireguardTransport,
+    WireguardCookieReply,
 )
 from cryptography.hazmat.primitives.asymmetric.x25519 import (
     X25519PrivateKey,
@@ -32,6 +33,9 @@
 from cryptography.hazmat.backends import default_backend
 from noise.connection import NoiseConnection, Keypair
 
+from Crypto.Cipher import ChaCha20_Poly1305
+from Crypto.Random import get_random_bytes
+
 from vpp_ipip_tun_interface import VppIpIpTunInterface
 from vpp_interface import VppInterface
 from vpp_ip_route import VppIpRoute, VppRoutePath
@@ -56,6 +60,11 @@
     return k.public_bytes(Encoding.Raw, PublicFormat.Raw)
 
 
+def get_field_bytes(pkt, name):
+    fld, val = pkt.getfield_and_val(name)
+    return fld.i2m(pkt, val)
+
+
 class VppWgInterface(VppInterface):
     """
     VPP WireGuard interface
@@ -151,6 +160,10 @@
         self.private_key = X25519PrivateKey.generate()
         self.public_key = self.private_key.public_key()
 
+        # cookie related params
+        self.cookie_key = blake2s(b"cookie--" + self.public_key_bytes()).digest()
+        self.last_sent_cookie = None
+
         self.noise = NoiseConnection.from_name(NOISE_HANDSHAKE_NAME)
 
     def add_vpp_config(self, is_ip6=False):
@@ -199,9 +212,6 @@
                 return True
         return False
 
-    def set_responder(self):
-        self.noise.set_as_responder()
-
     def mk_tunnel_header(self, tx_itf, is_ip6=False):
         if is_ip6 is False:
             return (
@@ -234,6 +244,55 @@
 
         self.noise.start_handshake()
 
+    def mk_cookie(self, p, tx_itf, is_resp=False, is_ip6=False):
+        self.verify_header(p, is_ip6)
+
+        wg_pkt = Wireguard(p[Raw])
+
+        if is_resp:
+            self._test.assertEqual(wg_pkt[Wireguard].message_type, 2)
+            self._test.assertEqual(wg_pkt[Wireguard].reserved_zero, 0)
+            self._test.assertEqual(wg_pkt[WireguardResponse].mac2, bytes([0] * 16))
+        else:
+            self._test.assertEqual(wg_pkt[Wireguard].message_type, 1)
+            self._test.assertEqual(wg_pkt[Wireguard].reserved_zero, 0)
+            self._test.assertEqual(wg_pkt[WireguardInitiation].mac2, bytes([0] * 16))
+
+        # collect info from wg packet (initiation or response)
+        src = get_field_bytes(p[IPv6 if is_ip6 else IP], "src")
+        sport = p[UDP].sport.to_bytes(2, byteorder="big")
+        if is_resp:
+            mac1 = wg_pkt[WireguardResponse].mac1
+            sender_index = wg_pkt[WireguardResponse].sender_index
+        else:
+            mac1 = wg_pkt[WireguardInitiation].mac1
+            sender_index = wg_pkt[WireguardInitiation].sender_index
+
+        # make cookie reply
+        cookie_reply = Wireguard() / WireguardCookieReply()
+        cookie_reply[Wireguard].message_type = 3
+        cookie_reply[Wireguard].reserved_zero = 0
+        cookie_reply[WireguardCookieReply].receiver_index = sender_index
+        nonce = get_random_bytes(24)
+        cookie_reply[WireguardCookieReply].nonce = nonce
+
+        # generate cookie data
+        changing_secret = get_random_bytes(32)
+        self.last_sent_cookie = blake2s(
+            src + sport, digest_size=16, key=changing_secret
+        ).digest()
+
+        # encrypt cookie data
+        cipher = ChaCha20_Poly1305.new(key=self.cookie_key, nonce=nonce)
+        cipher.update(mac1)
+        ciphertext, tag = cipher.encrypt_and_digest(self.last_sent_cookie)
+        cookie_reply[WireguardCookieReply].encrypted_cookie = ciphertext + tag
+
+        # prepare cookie reply to be sent
+        cookie_reply = self.mk_tunnel_header(tx_itf, is_ip6) / cookie_reply
+
+        return cookie_reply
+
     def mk_handshake(self, tx_itf, is_ip6=False, public_key=None):
         self.noise.set_as_initiator()
         self.noise_init(public_key)
@@ -281,7 +340,7 @@
         self._test.assertEqual(p[UDP].dport, self.port)
         self._test.assert_packet_checksums_valid(p)
 
-    def consume_init(self, p, tx_itf, is_ip6=False):
+    def consume_init(self, p, tx_itf, is_ip6=False, is_mac2=False):
         self.noise.set_as_responder()
         self.noise_init(self.itf.public_key)
         self.verify_header(p, is_ip6)
@@ -293,11 +352,23 @@
 
         self.sender = init[WireguardInitiation].sender_index
 
-        # validate the hash
+        # validate the mac1 hash
         mac_key = blake2s(b"mac1----" + public_key_bytes(self.public_key)).digest()
         mac1 = blake2s(bytes(init)[0:-32], digest_size=16, key=mac_key).digest()
         self._test.assertEqual(init[WireguardInitiation].mac1, mac1)
 
+        # validate the mac2 hash
+        if is_mac2:
+            self._test.assertNotEqual(init[WireguardInitiation].mac2, bytes([0] * 16))
+            self._test.assertNotEqual(self.last_sent_cookie, None)
+            mac2 = blake2s(
+                bytes(init)[0:-16], digest_size=16, key=self.last_sent_cookie
+            ).digest()
+            self._test.assertEqual(init[WireguardInitiation].mac2, mac2)
+            self.last_sent_cookie = None
+        else:
+            self._test.assertEqual(init[WireguardInitiation].mac2, bytes([0] * 16))
+
         # this passes only unencrypted_ephemeral, encrypted_static,
         # encrypted_timestamp fields of the init
         payload = self.noise.read_message(bytes(init)[8:-32])
@@ -398,6 +469,8 @@
     mac6_error = wg6_input_node_name + "Invalid MAC handshake"
     peer6_in_err = wg6_input_node_name + "Peer error"
     peer6_out_err = wg6_output_node_name + "Peer error"
+    cookie_dec4_err = wg4_input_node_name + "Failed during Cookie decryption"
+    cookie_dec6_err = wg6_input_node_name + "Failed during Cookie decryption"
 
     @classmethod
     def setUpClass(cls):
@@ -429,6 +502,12 @@
         self.base_mac6_err = self.statistics.get_err_counter(self.mac6_error)
         self.base_peer6_in_err = self.statistics.get_err_counter(self.peer6_in_err)
         self.base_peer6_out_err = self.statistics.get_err_counter(self.peer6_out_err)
+        self.base_cookie_dec4_err = self.statistics.get_err_counter(
+            self.cookie_dec4_err
+        )
+        self.base_cookie_dec6_err = self.statistics.get_err_counter(
+            self.cookie_dec6_err
+        )
 
     def test_wg_interface(self):
         """Simple interface creation"""
@@ -485,6 +564,88 @@
 
         self.assertEqual(tgt, act)
 
+    def _test_wg_send_cookie_tmpl(self, is_resp, is_ip6):
+        port = 12323
+
+        # create wg interface
+        if is_ip6:
+            wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip6()
+        else:
+            wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip4()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create a peer
+        if is_ip6:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_ip6, port + 1, ["1::3:0/112"]
+            ).add_vpp_config()
+        else:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_ip4, port + 1, ["10.11.3.0/24"]
+            ).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+        if is_resp:
+            # prepare and send a handshake initiation
+            # expect the peer to send a handshake response
+            init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6)
+            rxs = self.send_and_expect(self.pg1, [init], self.pg1)
+        else:
+            # wait for the peer to send a handshake initiation
+            rxs = self.pg1.get_capture(1, timeout=2)
+
+        # prepare and send a wrong cookie reply
+        # expect no replies and the cookie error incremented
+        cookie = peer_1.mk_cookie(rxs[0], self.pg1, is_resp=is_resp, is_ip6=is_ip6)
+        cookie.nonce = b"1234567890"
+        self.send_and_assert_no_replies(self.pg1, [cookie], timeout=0.1)
+        if is_ip6:
+            self.assertEqual(
+                self.base_cookie_dec6_err + 1,
+                self.statistics.get_err_counter(self.cookie_dec6_err),
+            )
+        else:
+            self.assertEqual(
+                self.base_cookie_dec4_err + 1,
+                self.statistics.get_err_counter(self.cookie_dec4_err),
+            )
+
+        # prepare and send a correct cookie reply
+        cookie = peer_1.mk_cookie(rxs[0], self.pg1, is_resp=is_resp, is_ip6=is_ip6)
+        self.pg_send(self.pg1, [cookie])
+
+        # wait for the peer to send a handshake initiation with mac2 set
+        rxs = self.pg1.get_capture(1, timeout=6)
+
+        # verify the initiation and its mac2
+        peer_1.consume_init(rxs[0], self.pg1, is_ip6=is_ip6, is_mac2=True)
+
+        # remove configs
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+    def test_wg_send_cookie_on_init_v4(self):
+        """Send cookie on handshake initiation (v4)"""
+        self._test_wg_send_cookie_tmpl(is_resp=False, is_ip6=False)
+
+    def test_wg_send_cookie_on_init_v6(self):
+        """Send cookie on handshake initiation (v6)"""
+        self._test_wg_send_cookie_tmpl(is_resp=False, is_ip6=True)
+
+    def test_wg_send_cookie_on_resp_v4(self):
+        """Send cookie on handshake response (v4)"""
+        self._test_wg_send_cookie_tmpl(is_resp=True, is_ip6=False)
+
+    def test_wg_send_cookie_on_resp_v6(self):
+        """Send cookie on handshake response (v6)"""
+        self._test_wg_send_cookie_tmpl(is_resp=True, is_ip6=True)
+
     def test_wg_peer_resp(self):
         """Send handshake response"""
         port = 12323