wireguard: add handshake rate limiting support

Type: feature

With this change, if a handshake message with both valid mac1 and
mac2 is received while under load, the peer will be rate limited.
Cover this with tests.

Signed-off-by: Alexander Chernavin <achernavin@netgate.com>
Change-Id: Id8d58bb293a7975c3d922c48b4948fd25e20af4b
diff --git a/src/plugins/wireguard/FEATURE.yaml b/src/plugins/wireguard/FEATURE.yaml
index cf8b6d7..4c6946d 100644
--- a/src/plugins/wireguard/FEATURE.yaml
+++ b/src/plugins/wireguard/FEATURE.yaml
@@ -8,5 +8,4 @@
 state: development
 properties: [API, CLI]
 missing:
-  - IPv6 support
-  - DoS protection as in the original protocol
+  - Peers roaming between different external IPs
diff --git a/src/plugins/wireguard/README.rst b/src/plugins/wireguard/README.rst
index cb7a024..ead4125 100644
--- a/src/plugins/wireguard/README.rst
+++ b/src/plugins/wireguard/README.rst
@@ -77,5 +77,4 @@
 -------------------------------------------------
 
 1. Use all benefits of VPP-engine.
-2. Add IPv6 support (currently only supports IPv4)
-3. Add DoS protection as in original protocol (using cookie)
+2. Add peers roaming support
diff --git a/src/plugins/wireguard/wireguard_cookie.c b/src/plugins/wireguard/wireguard_cookie.c
index 595b877..4ebbfa0 100644
--- a/src/plugins/wireguard/wireguard_cookie.c
+++ b/src/plugins/wireguard/wireguard_cookie.c
@@ -34,6 +34,11 @@
 					uint8_t[COOKIE_COOKIE_SIZE],
 					ip46_address_t *ip, u16 udp_port);
 
+static void ratelimit_init (ratelimit_t *, ratelimit_entry_t *);
+static void ratelimit_deinit (ratelimit_t *);
+static void ratelimit_gc (ratelimit_t *, bool);
+static bool ratelimit_allow (ratelimit_t *, ip46_address_t *);
+
 /* Public Functions */
 void
 cookie_maker_init (cookie_maker_t * cp, const uint8_t key[COOKIE_INPUT_SIZE])
@@ -44,6 +49,14 @@
 }
 
 void
+cookie_checker_init (cookie_checker_t *cc, ratelimit_entry_t *pool)
+{
+  clib_memset (cc, 0, sizeof (*cc));
+  ratelimit_init (&cc->cc_ratelimit_v4, pool);
+  ratelimit_init (&cc->cc_ratelimit_v6, pool);
+}
+
+void
 cookie_checker_update (cookie_checker_t * cc, uint8_t key[COOKIE_INPUT_SIZE])
 {
   if (key)
@@ -59,6 +72,13 @@
 }
 
 void
+cookie_checker_deinit (cookie_checker_t *cc)
+{
+  ratelimit_deinit (&cc->cc_ratelimit_v4);
+  ratelimit_deinit (&cc->cc_ratelimit_v6);
+}
+
+void
 cookie_checker_create_payload (vlib_main_t *vm, cookie_checker_t *cc,
 			       message_macs_t *cm,
 			       uint8_t nonce[COOKIE_NONCE_SIZE],
@@ -146,6 +166,13 @@
   if (clib_memcmp (our_cm.mac2, cm->mac2, COOKIE_MAC_SIZE) != 0)
     return VALID_MAC_BUT_NO_COOKIE;
 
+  /* If the mac2 is valid, we may want to rate limit the peer */
+  ratelimit_t *rl;
+  rl = ip46_address_is_ip4 (ip) ? &cc->cc_ratelimit_v4 : &cc->cc_ratelimit_v6;
+
+  if (!ratelimit_allow (rl, ip))
+    return VALID_MAC_WITH_COOKIE_BUT_RATELIMITED;
+
   return VALID_MAC_WITH_COOKIE;
 }
 
@@ -213,6 +240,170 @@
   blake2s_final (&state, cookie, COOKIE_COOKIE_SIZE);
 }
 
+/*
+ * Bind a rate limiter to the pool its entries are allocated from.
+ *
+ * Only the pointer value is stored. NOTE(review): the only caller chain in
+ * this change hands in wg_ratelimit_pool, which nothing appears to allocate,
+ * so the pool is created lazily by pool_get () in ratelimit_allow () and is
+ * then private to this limiter.
+ */
+static void
+ratelimit_init (ratelimit_t *rl, ratelimit_entry_t *pool)
+{
+  rl->rl_pool = pool;
+}
+
+/*
+ * Release everything the rate limiter holds: the entries, the lookup table
+ * and the entry pool itself.
+ */
+static void
+ratelimit_deinit (ratelimit_t *rl)
+{
+  ratelimit_gc (rl, /* force */ true);
+  hash_free (rl->rl_table);
+
+  /*
+   * The pool only ever grows through rl->rl_pool (pool_get () in
+   * ratelimit_allow ()), so nobody else can release it; without this the
+   * pool vector is leaked whenever a cookie checker is torn down.
+   * NOTE(review): revisit if the pool ever becomes shared between limiters.
+   */
+  pool_free (rl->rl_pool);
+}
+
+/*
+ * Return rate limiter entries to the pool.
+ *
+ * With force set, every entry referenced by the table is put back; the table
+ * itself is not modified. Otherwise, and only if more than ELEMENT_TIMEOUT
+ * has passed since the previous run, entries that were last hit more than
+ * ELEMENT_TIMEOUT ago are put back and unset from the table.
+ */
+static void
+ratelimit_gc (ratelimit_t *rl, bool force)
+{
+  u32 r_key;
+  u32 r_idx;
+  ratelimit_entry_t *r;
+
+  if (force)
+    {
+      /* clang-format off */
+      hash_foreach (r_key, r_idx, rl->rl_table, {
+	r = pool_elt_at_index (rl->rl_pool, r_idx);
+	pool_put (rl->rl_pool, r);
+      });
+      /* clang-format on */
+      return;
+    }
+
+  f64 now = vlib_time_now (vlib_get_main ());
+
+  if ((rl->rl_last_gc + ELEMENT_TIMEOUT) < now)
+    {
+      u32 *r_key_to_del = NULL;
+      u32 *pr_key;
+
+      rl->rl_last_gc = now;
+
+      /* Collect expired keys first; they are unset in a second pass below */
+      /* clang-format off */
+      hash_foreach (r_key, r_idx, rl->rl_table, {
+	r = pool_elt_at_index (rl->rl_pool, r_idx);
+	if ((r->r_last_time + ELEMENT_TIMEOUT) < now)
+	  {
+	    vec_add1 (r_key_to_del, r_key);
+	    pool_put (rl->rl_pool, r);
+	  }
+      });
+      /* clang-format on */
+
+      vec_foreach (pr_key, r_key_to_del)
+	{
+	  hash_unset (rl->rl_table, *pr_key);
+	}
+
+      vec_free (r_key_to_del);
+    }
+}
+
+/*
+ * Token bucket check for the source address of a handshake message.
+ *
+ * A new key starts with TOKEN_MAX tokens. A known key earns NSEC_PER_SEC
+ * tokens per second elapsed since it was last seen here, capped at
+ * TOKEN_MAX. Every allowed handshake costs INITIATION_COST tokens.
+ *
+ * Returns true if the handshake may proceed. Returns false if the key has
+ * too few tokens, or if the key is new and the table already holds
+ * RATELIMIT_SIZE_MAX entries after garbage collection.
+ */
+static bool
+ratelimit_allow (ratelimit_t *rl, ip46_address_t *ip)
+{
+  u32 r_key;
+  uword *p;
+  u32 r_idx;
+  ratelimit_entry_t *r;
+  f64 now = vlib_time_now (vlib_get_main ());
+
+  if (ip46_address_is_ip4 (ip))
+    /* Use all 4 bytes of IPv4 address */
+    r_key = ip->ip4.as_u32;
+  else
+    /*
+     * Fold the top 8 bytes (/64) of the IPv6 address into 32 bits. Distinct
+     * prefixes whose halves XOR to the same value share one entry.
+     */
+    r_key = ip->ip6.as_u32[0] ^ ip->ip6.as_u32[1];
+
+  /* Check if there is already an entry for the IP address */
+  p = hash_get (rl->rl_table, r_key);
+  if (p)
+    {
+      u64 tokens;
+      f64 diff;
+
+      r_idx = p[0];
+      r = pool_elt_at_index (rl->rl_pool, r_idx);
+
+      diff = now - r->r_last_time;
+      r->r_last_time = now;
+
+      tokens = r->r_tokens + diff * NSEC_PER_SEC;
+
+      if (tokens > TOKEN_MAX)
+	tokens = TOKEN_MAX;
+
+      if (tokens >= INITIATION_COST)
+	{
+	  r->r_tokens = tokens - INITIATION_COST;
+	  return true;
+	}
+
+      r->r_tokens = tokens;
+      return false;
+    }
+
+  /* No entry for the IP address */
+  ratelimit_gc (rl, /* force */ false);
+
+  if (hash_elts (rl->rl_table) >= RATELIMIT_SIZE_MAX)
+    return false;
+
+  pool_get (rl->rl_pool, r);
+  r_idx = r - rl->rl_pool;
+  hash_set (rl->rl_table, r_key, r_idx);
+
+  /* The first handshake is allowed and already pays its cost */
+  r->r_last_time = now;
+  r->r_tokens = TOKEN_MAX - INITIATION_COST;
+
+  return true;
+}
+
 /*
  * fd.io coding-style-patch-verification: ON
  *
diff --git a/src/plugins/wireguard/wireguard_cookie.h b/src/plugins/wireguard/wireguard_cookie.h
index 9298ece..7467cf2 100644
--- a/src/plugins/wireguard/wireguard_cookie.h
+++ b/src/plugins/wireguard/wireguard_cookie.h
@@ -25,7 +25,8 @@
 {
   INVALID_MAC,
   VALID_MAC_BUT_NO_COOKIE,
-  VALID_MAC_WITH_COOKIE
+  VALID_MAC_WITH_COOKIE,
+  VALID_MAC_WITH_COOKIE_BUT_RATELIMITED,
 };
 
 #define COOKIE_MAC_SIZE		16
@@ -50,8 +51,6 @@
 #define INITIATION_COST		(NSEC_PER_SEC / INITIATIONS_PER_SECOND)
 #define TOKEN_MAX		(INITIATION_COST * INITIATIONS_BURSTABLE)
 #define ELEMENT_TIMEOUT		1
-#define IPV4_MASK_SIZE		4	/* Use all 4 bytes of IPv4 address */
-#define IPV6_MASK_SIZE		8	/* Use top 8 bytes (/64) of IPv6 address */
 
 typedef struct cookie_macs
 {
@@ -59,6 +58,19 @@
   uint8_t mac2[COOKIE_MAC_SIZE];
 } message_macs_t;
 
+typedef struct ratelimit_entry
+{
+  f64 r_last_time; /* time this key was last seen by ratelimit_allow () */
+  u64 r_tokens;    /* tokens left after that check, at most TOKEN_MAX */
+} ratelimit_entry_t;
+
+typedef struct ratelimit
+{
+  ratelimit_entry_t *rl_pool; /* pool the entries are allocated from */
+  uword *rl_table;            /* hash: 32-bit address key -> pool index */
+  f64 rl_last_gc;             /* time of the last non-forced gc pass */
+} ratelimit_t;
+
 typedef struct cookie_maker
 {
   uint8_t cp_mac1_key[COOKIE_KEY_SIZE];
@@ -72,6 +84,9 @@
 
 typedef struct cookie_checker
 {
+  ratelimit_t cc_ratelimit_v4;
+  ratelimit_t cc_ratelimit_v6;
+
   uint8_t cc_mac1_key[COOKIE_KEY_SIZE];
   uint8_t cc_cookie_key[COOKIE_KEY_SIZE];
 
@@ -81,7 +96,9 @@
 
 
 void cookie_maker_init (cookie_maker_t *, const uint8_t[COOKIE_INPUT_SIZE]);
+void cookie_checker_init (cookie_checker_t *, ratelimit_entry_t *);
 void cookie_checker_update (cookie_checker_t *, uint8_t[COOKIE_INPUT_SIZE]);
+void cookie_checker_deinit (cookie_checker_t *);
 void cookie_checker_create_payload (vlib_main_t *vm, cookie_checker_t *cc,
 				    message_macs_t *cm,
 				    uint8_t nonce[COOKIE_NONCE_SIZE],
diff --git a/src/plugins/wireguard/wireguard_if.c b/src/plugins/wireguard/wireguard_if.c
index c4199d2..a869df0 100644
--- a/src/plugins/wireguard/wireguard_if.c
+++ b/src/plugins/wireguard/wireguard_if.c
@@ -34,6 +34,9 @@
 /* vector of interfaces key'd on their UDP port (in network order) */
 index_t **wg_if_indexes_by_port;
 
+/* pool of ratelimit entries */
+static ratelimit_entry_t *wg_ratelimit_pool;
+
 static u8 *
 format_wg_if_name (u8 * s, va_list * args)
 {
@@ -309,6 +312,7 @@
 
   wg_if->port = port;
   wg_if->local_idx = local - noise_local_pool;
+  cookie_checker_init (&wg_if->cookie_checker, wg_ratelimit_pool);
   cookie_checker_update (&wg_if->cookie_checker, local->l_public);
 
   hw_if_index = vnet_register_interface (vnm,
@@ -372,6 +376,8 @@
       udp_unregister_dst_port (vlib_get_main (), wg_if->port, 0);
     }
 
+  cookie_checker_deinit (&wg_if->cookie_checker);
+
   vnet_reset_interface_l3_output_node (vnm->vlib_main, sw_if_index);
   vnet_delete_hw_interface (vnm, hw->hw_if_index);
   pool_put_index (noise_local_pool, wg_if->local_idx);
diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c
index 3f546cc..b85cdc6 100644
--- a/src/plugins/wireguard/wireguard_input.c
+++ b/src/plugins/wireguard/wireguard_input.c
@@ -25,6 +25,7 @@
 #define foreach_wg_input_error                                                \
   _ (NONE, "No error")                                                        \
   _ (HANDSHAKE_MAC, "Invalid MAC handshake")                                  \
+  _ (HANDSHAKE_RATELIMITED, "Handshake ratelimited")                          \
   _ (PEER, "Peer error")                                                      \
   _ (INTERFACE, "Interface error")                                            \
   _ (DECRYPTION, "Failed during decryption")                                  \
@@ -232,6 +233,8 @@
     packet_needs_cookie = false;
   else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
     packet_needs_cookie = true;
+  else if (mac_state == VALID_MAC_WITH_COOKIE_BUT_RATELIMITED)
+    return WG_INPUT_ERROR_HANDSHAKE_RATELIMITED;
   else
     return WG_INPUT_ERROR_HANDSHAKE_MAC;
 
diff --git a/test/test_wireguard.py b/test/test_wireguard.py
index 564dee2..b8c5d2a 100644
--- a/test/test_wireguard.py
+++ b/test/test_wireguard.py
@@ -152,6 +152,7 @@
 HANDSHAKE_COUNTING_INTERVAL = 0.5
 UNDER_LOAD_INTERVAL = 1.0
 HANDSHAKE_NUM_PER_PEER_UNTIL_UNDER_LOAD = 40
+HANDSHAKE_NUM_BEFORE_RATELIMITING = 5
 
 
 class VppWgPeer(VppObject):
@@ -514,6 +515,8 @@
     peer6_out_err = wg6_output_node_name + "Peer error"
     cookie_dec4_err = wg4_input_node_name + "Failed during Cookie decryption"
     cookie_dec6_err = wg6_input_node_name + "Failed during Cookie decryption"
+    ratelimited4_err = wg4_input_node_name + "Handshake ratelimited"
+    ratelimited6_err = wg6_input_node_name + "Handshake ratelimited"
 
     @classmethod
     def setUpClass(cls):
@@ -551,6 +554,12 @@
         self.base_cookie_dec6_err = self.statistics.get_err_counter(
             self.cookie_dec6_err
         )
+        self.base_ratelimited4_err = self.statistics.get_err_counter(
+            self.ratelimited4_err
+        )
+        self.base_ratelimited6_err = self.statistics.get_err_counter(
+            self.ratelimited6_err
+        )
 
     def test_wg_interface(self):
         """Simple interface creation"""
@@ -829,6 +838,165 @@
         peer_1.remove_vpp_config()
         wg0.remove_vpp_config()
 
+    def _test_wg_handshake_ratelimiting_tmpl(self, is_ip6):
+        port = 12323
+
+        # create wg interface
+        if is_ip6:
+            wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip6()
+        else:
+            wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip4()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create a peer
+        if is_ip6:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_ip6, port + 1, ["1::3:0/112"]
+            ).add_vpp_config()
+        else:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_ip4, port + 1, ["10.11.3.0/24"]
+            ).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+        # prepare and send a bunch of handshake initiations
+        # expect to switch to under load state
+        init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6)
+        txs = [init] * HANDSHAKE_NUM_PER_PEER_UNTIL_UNDER_LOAD
+        rxs = self.send_and_expect_some(self.pg1, txs, self.pg1)
+
+        # expect the peer to send a cookie reply
+        peer_1.consume_cookie(rxs[-1], is_ip6=is_ip6)
+
+        # prepare and send a bunch of handshake initiations with correct mac2
+        # expect a handshake response and then ratelimiting
+        NUM_TO_REJECT = 10
+        init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6)
+        txs = [init] * (HANDSHAKE_NUM_BEFORE_RATELIMITING + NUM_TO_REJECT)
+        rxs = self.send_and_expect_some(self.pg1, txs, self.pg1)
+
+        if is_ip6:
+            self.assertEqual(
+                self.base_ratelimited6_err + NUM_TO_REJECT,
+                self.statistics.get_err_counter(self.ratelimited6_err),
+            )
+        else:
+            self.assertEqual(
+                self.base_ratelimited4_err + NUM_TO_REJECT,
+                self.statistics.get_err_counter(self.ratelimited4_err),
+            )
+
+        # verify the response
+        peer_1.consume_response(rxs[0], is_ip6=is_ip6)
+
+        # clear up under load state
+        self.sleep(UNDER_LOAD_INTERVAL)
+
+        # remove configs
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+    def test_wg_handshake_ratelimiting_v4(self):
+        """Handshake ratelimiting (v4)"""
+        self._test_wg_handshake_ratelimiting_tmpl(is_ip6=False)
+
+    def test_wg_handshake_ratelimiting_v6(self):
+        """Handshake ratelimiting (v6)"""
+        self._test_wg_handshake_ratelimiting_tmpl(is_ip6=True)
+
+    def test_wg_handshake_ratelimiting_multi_peer(self):
+        """Handshake ratelimiting (multiple peer)"""
+        port = 12323
+
+        # create wg interface
+        wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+        wg0.admin_up()
+        wg0.config_ip4()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create two peers
+        NUM_PEERS = 2
+        self.pg1.generate_remote_hosts(NUM_PEERS)
+        self.pg1.configure_ipv4_neighbors()
+
+        peer_1 = VppWgPeer(
+            self, wg0, self.pg1.remote_hosts[0].ip4, port + 1, ["10.11.3.0/24"]
+        ).add_vpp_config()
+        peer_2 = VppWgPeer(
+            self, wg0, self.pg1.remote_hosts[1].ip4, port + 1, ["10.11.4.0/24"]
+        ).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 2)
+
+        # (peer_1) prepare and send a bunch of handshake initiations
+        # expect not to switch to under load state
+        init_1 = peer_1.mk_handshake(self.pg1)
+        txs = [init_1] * HANDSHAKE_NUM_PER_PEER_UNTIL_UNDER_LOAD
+        rxs = self.send_and_expect_some(self.pg1, txs, self.pg1)
+
+        # (peer_1) expect the peer to send a handshake response
+        peer_1.consume_response(rxs[0])
+        peer_1.noise_reset()
+
+        # (peer_1) send another bunch of handshake initiations
+        # expect to switch to under load state
+        rxs = self.send_and_expect_some(self.pg1, txs, self.pg1)
+
+        # (peer_1) expect the peer to send a cookie reply
+        peer_1.consume_cookie(rxs[-1])
+
+        # (peer_2) prepare and send a handshake initiation
+        # expect a cookie reply
+        init_2 = peer_2.mk_handshake(self.pg1)
+        rxs = self.send_and_expect(self.pg1, [init_2], self.pg1)
+        peer_2.consume_cookie(rxs[0])
+
+        # (peer_1) prepare and send a bunch of handshake initiations with correct mac2
+        # expect no ratelimiting and a handshake response
+        init_1 = peer_1.mk_handshake(self.pg1)
+        txs = [init_1] * HANDSHAKE_NUM_BEFORE_RATELIMITING
+        rxs = self.send_and_expect_some(self.pg1, txs, self.pg1)
+        self.assertEqual(
+            self.base_ratelimited4_err,
+            self.statistics.get_err_counter(self.ratelimited4_err),
+        )
+
+        # (peer_1) verify the response
+        peer_1.consume_response(rxs[0])
+        peer_1.noise_reset()
+
+        # (peer_1) send another two handshake initiations with correct mac2
+        # expect ratelimiting
+        # (peer_2) prepare and send a handshake initiation with correct mac2
+        # expect no ratelimiting and a handshake response
+        init_2 = peer_2.mk_handshake(self.pg1)
+        txs = [init_1, init_2, init_1]
+        rxs = self.send_and_expect_some(self.pg1, txs, self.pg1)
+
+        # (peer_1) verify ratelimiting
+        self.assertEqual(
+            self.base_ratelimited4_err + 2,
+            self.statistics.get_err_counter(self.ratelimited4_err),
+        )
+
+        # (peer_2) verify the response
+        peer_2.consume_response(rxs[0])
+
+        # clear up under load state
+        self.sleep(UNDER_LOAD_INTERVAL)
+
+        # remove configs
+        peer_1.remove_vpp_config()
+        peer_2.remove_vpp_config()
+        wg0.remove_vpp_config()
+
     def test_wg_peer_resp(self):
         """Send handshake response"""
         port = 12323