ipsec: Reference count the SAs

- this remove the need to iterate through all state when deleting an SA
- and ensures that if the SA is deleted by the client is remains for use
in any state until that state is also removed.

Type: feature

Change-Id: I438cb67588cb65c701e49a7a9518f88641925419
Signed-off-by: Neale Ranns <nranns@cisco.com>
diff --git a/src/plugins/unittest/ipsec_test.c b/src/plugins/unittest/ipsec_test.c
index ec39a2e..c40e954 100644
--- a/src/plugins/unittest/ipsec_test.c
+++ b/src/plugins/unittest/ipsec_test.c
@@ -42,11 +42,13 @@
       ipsec_sa_t *sa;
       u32 sa_index;
 
-      sa_index = ipsec_get_sa_index_by_sa_id (sa_id);
+      sa_index = ipsec_sa_find_and_lock (sa_id);
       sa = pool_elt_at_index (im->sad, sa_index);
 
       sa->seq = seq_num & 0xffffffff;
       sa->seq_hi = seq_num >> 32;
+
+      ipsec_sa_unlock (sa_index);
     }
   else
     {
diff --git a/src/vnet/ipsec/ipsec_api.c b/src/vnet/ipsec/ipsec_api.c
index 6de0203..371e4fe 100644
--- a/src/vnet/ipsec/ipsec_api.c
+++ b/src/vnet/ipsec/ipsec_api.c
@@ -513,12 +513,13 @@
   ip_address_decode (&mp->entry.tunnel_dst, &tun_dst);
 
   if (mp->is_add)
-    rv = ipsec_sa_add (id, spi, proto,
-		       crypto_alg, &crypto_key,
-		       integ_alg, &integ_key, flags,
-		       0, mp->entry.salt, &tun_src, &tun_dst, &sa_index);
+    rv = ipsec_sa_add_and_lock (id, spi, proto,
+				crypto_alg, &crypto_key,
+				integ_alg, &integ_key, flags,
+				0, mp->entry.salt, &tun_src, &tun_dst,
+				&sa_index);
   else
-    rv = ipsec_sa_del (id);
+    rv = ipsec_sa_unlock_id (id);
 
 #else
   rv = VNET_API_ERROR_UNIMPLEMENTED;
diff --git a/src/vnet/ipsec/ipsec_cli.c b/src/vnet/ipsec/ipsec_cli.c
index 60b9244..a5972bb 100644
--- a/src/vnet/ipsec/ipsec_cli.c
+++ b/src/vnet/ipsec/ipsec_cli.c
@@ -144,12 +144,12 @@
     }
 
   if (is_add)
-    rv = ipsec_sa_add (id, spi, proto, crypto_alg,
-		       &ck, integ_alg, &ik, flags,
-		       0, clib_host_to_net_u32 (salt),
-		       &tun_src, &tun_dst, NULL);
+    rv = ipsec_sa_add_and_lock (id, spi, proto, crypto_alg,
+				&ck, integ_alg, &ik, flags,
+				0, clib_host_to_net_u32 (salt),
+				&tun_src, &tun_dst, NULL);
   else
-    rv = ipsec_sa_del (id);
+    rv = ipsec_sa_unlock_id (id);
 
   if (rv)
     error = clib_error_return (0, "failed");
diff --git a/src/vnet/ipsec/ipsec_format.c b/src/vnet/ipsec/ipsec_format.c
index a0cd5ad..0d596c0 100644
--- a/src/vnet/ipsec/ipsec_format.c
+++ b/src/vnet/ipsec/ipsec_format.c
@@ -285,8 +285,8 @@
 
   sa = pool_elt_at_index (im->sad, sai);
 
-  s = format (s, "[%d] sa 0x%x spi %u (0x%08x) mode %s%s protocol %s %U",
-	      sai, sa->id, sa->spi, sa->spi,
+  s = format (s, "[%d] sa %d (0x%x) spi %u (0x%08x) mode %s%s protocol %s %U",
+	      sai, sa->id, sa->id, sa->spi, sa->spi,
 	      ipsec_sa_is_set_IS_TUNNEL (sa) ? "tunnel" : "transport",
 	      ipsec_sa_is_set_IS_TUNNEL_V6 (sa) ? "-ip6" : "",
 	      sa->protocol ? "esp" : "ah", format_ipsec_sa_flags, sa->flags);
@@ -294,6 +294,7 @@
   if (!(flags & IPSEC_FORMAT_DETAIL))
     goto done;
 
+  s = format (s, "\n   locks %d", sa->node.fn_locks);
   s = format (s, "\n   salt 0x%x", clib_net_to_host_u32 (sa->salt));
   s = format (s, "\n   seq %u seq-hi %u", sa->seq, sa->seq_hi);
   s = format (s, "\n   last-seq %u last-seq-hi %u window %U",
diff --git a/src/vnet/ipsec/ipsec_if.c b/src/vnet/ipsec/ipsec_if.c
index 8e4f3f1..5fc49e1 100644
--- a/src/vnet/ipsec/ipsec_if.c
+++ b/src/vnet/ipsec/ipsec_if.c
@@ -321,18 +321,18 @@
       ipsec_mk_key (&integ_key,
 		    args->remote_integ_key, args->remote_integ_key_len);
 
-      rv = ipsec_sa_add (ipsec_tun_mk_input_sa_id (dev_instance),
-			 args->remote_spi,
-			 IPSEC_PROTOCOL_ESP,
-			 args->crypto_alg,
-			 &crypto_key,
-			 args->integ_alg,
-			 &integ_key,
-			 (flags | IPSEC_SA_FLAG_IS_INBOUND),
-			 args->tx_table_id,
-			 args->salt,
-			 &args->remote_ip,
-			 &args->local_ip, &t->input_sa_index);
+      rv = ipsec_sa_add_and_lock (ipsec_tun_mk_input_sa_id (dev_instance),
+				  args->remote_spi,
+				  IPSEC_PROTOCOL_ESP,
+				  args->crypto_alg,
+				  &crypto_key,
+				  args->integ_alg,
+				  &integ_key,
+				  (flags | IPSEC_SA_FLAG_IS_INBOUND),
+				  args->tx_table_id,
+				  args->salt,
+				  &args->remote_ip,
+				  &args->local_ip, &t->input_sa_index);
 
       if (rv)
 	return rv;
@@ -342,18 +342,18 @@
       ipsec_mk_key (&integ_key,
 		    args->local_integ_key, args->local_integ_key_len);
 
-      rv = ipsec_sa_add (ipsec_tun_mk_output_sa_id (dev_instance),
-			 args->local_spi,
-			 IPSEC_PROTOCOL_ESP,
-			 args->crypto_alg,
-			 &crypto_key,
-			 args->integ_alg,
-			 &integ_key,
-			 flags,
-			 args->tx_table_id,
-			 args->salt,
-			 &args->local_ip,
-			 &args->remote_ip, &t->output_sa_index);
+      rv = ipsec_sa_add_and_lock (ipsec_tun_mk_output_sa_id (dev_instance),
+				  args->local_spi,
+				  IPSEC_PROTOCOL_ESP,
+				  args->crypto_alg,
+				  &crypto_key,
+				  args->integ_alg,
+				  &integ_key,
+				  flags,
+				  args->tx_table_id,
+				  args->salt,
+				  &args->local_ip,
+				  &args->remote_ip, &t->output_sa_index);
 
       if (rv)
 	return rv;
@@ -420,11 +420,11 @@
       hash_unset (im->ipsec_if_real_dev_by_show_dev, t->show_instance);
       im->ipsec_if_by_sw_if_index[t->sw_if_index] = ~0;
 
-      pool_put (im->tunnel_interfaces, t);
-
       /* delete input and output SA */
-      ipsec_sa_del (ipsec_tun_mk_input_sa_id (ti));
-      ipsec_sa_del (ipsec_tun_mk_output_sa_id (ti));
+      ipsec_sa_unlock (t->input_sa_index);
+      ipsec_sa_unlock (t->output_sa_index);
+
+      pool_put (im->tunnel_interfaces, t);
     }
 
   if (sw_if_index)
@@ -447,17 +447,12 @@
   hi = vnet_get_hw_interface (vnm, hw_if_index);
   t = pool_elt_at_index (im->tunnel_interfaces, hi->dev_instance);
 
-  sa_index = ipsec_get_sa_index_by_sa_id (sa_id);
-  if (sa_index == ~0)
+  sa_index = ipsec_sa_find_and_lock (sa_id);
+
+  if (INDEX_INVALID == sa_index)
     {
       clib_warning ("SA with ID %u not found", sa_id);
-      return VNET_API_ERROR_INVALID_VALUE;
-    }
-
-  if (ipsec_is_sa_used (sa_index))
-    {
-      clib_warning ("SA with ID %u is already in use", sa_id);
-      return VNET_API_ERROR_INVALID_VALUE;
+      return VNET_API_ERROR_NO_SUCH_ENTRY;
     }
 
   sa = pool_elt_at_index (im->sad, sa_index);
@@ -537,15 +532,15 @@
     }
 
   /* remove sa_id to sa_index mapping on old SA */
-  if (ipsec_get_sa_index_by_sa_id (old_sa->id) == old_sa_index)
-    hash_unset (im->sa_index_by_sa_id, old_sa->id);
+  hash_unset (im->sa_index_by_sa_id, old_sa->id);
 
   if (ipsec_add_del_sa_sess_cb (im, old_sa_index, 0))
     {
       clib_warning ("IPsec backend add/del callback returned error");
       return VNET_API_ERROR_SYSCALL_ERROR_1;
     }
-  ipsec_sa_del (old_sa->id);
+
+  ipsec_sa_unlock (old_sa_index);
 
   return 0;
 }
diff --git a/src/vnet/ipsec/ipsec_sa.c b/src/vnet/ipsec/ipsec_sa.c
index afdecfe..e3eff58 100644
--- a/src/vnet/ipsec/ipsec_sa.c
+++ b/src/vnet/ipsec/ipsec_sa.c
@@ -123,18 +123,18 @@
 }
 
 int
-ipsec_sa_add (u32 id,
-	      u32 spi,
-	      ipsec_protocol_t proto,
-	      ipsec_crypto_alg_t crypto_alg,
-	      const ipsec_key_t * ck,
-	      ipsec_integ_alg_t integ_alg,
-	      const ipsec_key_t * ik,
-	      ipsec_sa_flags_t flags,
-	      u32 tx_table_id,
-	      u32 salt,
-	      const ip46_address_t * tun_src,
-	      const ip46_address_t * tun_dst, u32 * sa_out_index)
+ipsec_sa_add_and_lock (u32 id,
+		       u32 spi,
+		       ipsec_protocol_t proto,
+		       ipsec_crypto_alg_t crypto_alg,
+		       const ipsec_key_t * ck,
+		       ipsec_integ_alg_t integ_alg,
+		       const ipsec_key_t * ik,
+		       ipsec_sa_flags_t flags,
+		       u32 tx_table_id,
+		       u32 salt,
+		       const ip46_address_t * tun_src,
+		       const ip46_address_t * tun_dst, u32 * sa_out_index)
 {
   vlib_main_t *vm = vlib_get_main ();
   ipsec_main_t *im = &ipsec_main;
@@ -150,6 +150,7 @@
   pool_get_aligned_zero (im->sad, sa, CLIB_CACHE_LINE_BYTES);
 
   fib_node_init (&sa->node, FIB_NODE_TYPE_IPSEC_SA);
+  fib_node_lock (&sa->node);
   sa_index = sa - im->sad;
 
   vlib_validate_combined_counter (&ipsec_sa_counters, sa_index);
@@ -272,33 +273,18 @@
   return (0);
 }
 
-u32
-ipsec_sa_del (u32 id)
+static void
+ipsec_sa_del (ipsec_sa_t * sa)
 {
   vlib_main_t *vm = vlib_get_main ();
   ipsec_main_t *im = &ipsec_main;
-  ipsec_sa_t *sa = 0;
-  uword *p;
   u32 sa_index;
-  clib_error_t *err;
 
-  p = hash_get (im->sa_index_by_sa_id, id);
-
-  if (!p)
-    return VNET_API_ERROR_NO_SUCH_ENTRY;
-
-  sa_index = p[0];
-  sa = pool_elt_at_index (im->sad, sa_index);
-  if (ipsec_is_sa_used (sa_index))
-    {
-      clib_warning ("sa_id %u used in policy", sa->id);
-      /* sa used in policy */
-      return VNET_API_ERROR_RSRC_IN_USE;
-    }
+  sa_index = sa - im->sad;
   hash_unset (im->sa_index_by_sa_id, sa->id);
-  err = ipsec_call_add_del_callbacks (im, sa, sa_index, 0);
-  if (err)
-    return VNET_API_ERROR_SYSCALL_ERROR_2;
+
+  /* no recovery possible when deleting an SA */
+  (void) ipsec_call_add_del_callbacks (im, sa, sa_index, 0);
 
   if (ipsec_sa_is_set_IS_TUNNEL (sa) && !ipsec_sa_is_set_IS_INBOUND (sa))
     {
@@ -311,7 +297,55 @@
   vnet_crypto_key_del (vm, sa->crypto_key_index);
   vnet_crypto_key_del (vm, sa->integ_key_index);
   pool_put (im->sad, sa);
-  return 0;
+}
+
+void
+ipsec_sa_unlock (index_t sai)
+{
+  ipsec_main_t *im = &ipsec_main;
+  ipsec_sa_t *sa;
+
+  if (INDEX_INVALID == sai)
+    return;
+
+  sa = pool_elt_at_index (im->sad, sai);
+
+  fib_node_unlock (&sa->node);
+}
+
+index_t
+ipsec_sa_find_and_lock (u32 id)
+{
+  ipsec_main_t *im = &ipsec_main;
+  ipsec_sa_t *sa;
+  uword *p;
+
+  p = hash_get (im->sa_index_by_sa_id, id);
+
+  if (!p)
+    return INDEX_INVALID;
+
+  sa = pool_elt_at_index (im->sad, p[0]);
+
+  fib_node_lock (&sa->node);
+
+  return (p[0]);
+}
+
+int
+ipsec_sa_unlock_id (u32 id)
+{
+  ipsec_main_t *im = &ipsec_main;
+  uword *p;
+
+  p = hash_get (im->sa_index_by_sa_id, id);
+
+  if (!p)
+    return VNET_API_ERROR_NO_SUCH_ENTRY;
+
+  ipsec_sa_unlock (p[0]);
+
+  return (0);
 }
 
 void
@@ -320,58 +354,6 @@
   vlib_zero_combined_counter (&ipsec_sa_counters, sai);
 }
 
-u8
-ipsec_is_sa_used (u32 sa_index)
-{
-  ipsec_main_t *im = &ipsec_main;
-  ipsec_tun_protect_t *itp;
-  ipsec_tunnel_if_t *t;
-  ipsec_policy_t *p;
-  u32 sai;
-
-  /* *INDENT-OFF* */
-  pool_foreach(p, im->policies, ({
-     if (p->policy == IPSEC_POLICY_ACTION_PROTECT)
-       {
-         if (p->sa_index == sa_index)
-           return 1;
-       }
-  }));
-
-  pool_foreach(t, im->tunnel_interfaces, ({
-    if (t->input_sa_index == sa_index)
-      return 1;
-    if (t->output_sa_index == sa_index)
-      return 1;
-  }));
-
-  /* *INDENT-OFF* */
-  pool_foreach(itp, ipsec_protect_pool, ({
-    FOR_EACH_IPSEC_PROTECT_INPUT_SAI(itp, sai,
-    ({
-      if (sai == sa_index)
-        return 1;
-    }));
-    if (itp->itp_out_sa == sa_index)
-      return 1;
-  }));
-  /* *INDENT-ON* */
-
-
-  return 0;
-}
-
-u32
-ipsec_get_sa_index_by_sa_id (u32 sa_id)
-{
-  ipsec_main_t *im = &ipsec_main;
-  uword *p = hash_get (im->sa_index_by_sa_id, sa_id);
-  if (!p)
-    return ~0;
-
-  return p[0];
-}
-
 void
 ipsec_sa_walk (ipsec_sa_walk_cb_t cb, void *ctx)
 {
@@ -402,6 +384,15 @@
   return (&sa->node);
 }
 
+static ipsec_sa_t *
+ipsec_sa_from_fib_node (fib_node_t * node)
+{
+  ASSERT (FIB_NODE_TYPE_IPSEC_SA == node->fn_type);
+  return ((ipsec_sa_t *) (((char *) node) -
+			  STRUCT_OFFSET_OF (ipsec_sa_t, node)));
+
+}
+
 /**
  * Function definition to inform the FIB node that its last lock has gone.
  */
@@ -412,16 +403,7 @@
    * The ipsec SA is a root of the graph. As such
    * it never has children and thus is never locked.
    */
-  ASSERT (0);
-}
-
-static ipsec_sa_t *
-ipsec_sa_from_fib_node (fib_node_t * node)
-{
-  ASSERT (FIB_NODE_TYPE_IPSEC_SA == node->fn_type);
-  return ((ipsec_sa_t *) (((char *) node) -
-			  STRUCT_OFFSET_OF (ipsec_sa_t, node)));
-
+  ipsec_sa_del (ipsec_sa_from_fib_node (node));
 }
 
 /**
diff --git a/src/vnet/ipsec/ipsec_sa.h b/src/vnet/ipsec/ipsec_sa.h
index 2848267..811f4ca 100644
--- a/src/vnet/ipsec/ipsec_sa.h
+++ b/src/vnet/ipsec/ipsec_sa.h
@@ -140,7 +140,6 @@
   };
   udp_header_t udp_hdr;
 
-
   fib_node_t node;
   u32 id;
   u32 stat_index;
@@ -198,29 +197,28 @@
 
 extern void ipsec_mk_key (ipsec_key_t * key, const u8 * data, u8 len);
 
-extern int ipsec_sa_add (u32 id,
-			 u32 spi,
-			 ipsec_protocol_t proto,
-			 ipsec_crypto_alg_t crypto_alg,
-			 const ipsec_key_t * ck,
-			 ipsec_integ_alg_t integ_alg,
-			 const ipsec_key_t * ik,
-			 ipsec_sa_flags_t flags,
-			 u32 tx_table_id,
-			 u32 salt,
-			 const ip46_address_t * tunnel_src_addr,
-			 const ip46_address_t * tunnel_dst_addr,
-			 u32 * sa_index);
-extern u32 ipsec_sa_del (u32 id);
+extern int ipsec_sa_add_and_lock (u32 id,
+				  u32 spi,
+				  ipsec_protocol_t proto,
+				  ipsec_crypto_alg_t crypto_alg,
+				  const ipsec_key_t * ck,
+				  ipsec_integ_alg_t integ_alg,
+				  const ipsec_key_t * ik,
+				  ipsec_sa_flags_t flags,
+				  u32 tx_table_id,
+				  u32 salt,
+				  const ip46_address_t * tunnel_src_addr,
+				  const ip46_address_t * tunnel_dst_addr,
+				  u32 * sa_index);
+extern index_t ipsec_sa_find_and_lock (u32 id);
+extern int ipsec_sa_unlock_id (u32 id);
+extern void ipsec_sa_unlock (index_t sai);
 extern void ipsec_sa_clear (index_t sai);
 extern void ipsec_sa_set_crypto_alg (ipsec_sa_t * sa,
 				     ipsec_crypto_alg_t crypto_alg);
 extern void ipsec_sa_set_integ_alg (ipsec_sa_t * sa,
 				    ipsec_integ_alg_t integ_alg);
 
-extern u8 ipsec_is_sa_used (u32 sa_index);
-extern u32 ipsec_get_sa_index_by_sa_id (u32 sa_id);
-
 typedef walk_rc_t (*ipsec_sa_walk_cb_t) (ipsec_sa_t * sa, void *ctx);
 extern void ipsec_sa_walk (ipsec_sa_walk_cb_t cd, void *ctx);
 
diff --git a/src/vnet/ipsec/ipsec_spd_policy.c b/src/vnet/ipsec/ipsec_spd_policy.c
index 34b7dc2..6424210 100644
--- a/src/vnet/ipsec/ipsec_spd_policy.c
+++ b/src/vnet/ipsec/ipsec_spd_policy.c
@@ -142,14 +142,6 @@
   u32 spd_index;
   uword *p;
 
-  if (policy->policy == IPSEC_POLICY_ACTION_PROTECT)
-    {
-      p = hash_get (im->sa_index_by_sa_id, policy->sa_id);
-      if (!p)
-	return VNET_API_ERROR_SYSCALL_ERROR_1;
-      policy->sa_index = p[0];
-    }
-
   p = hash_get (im->spd_index_by_spd_id, policy->id);
 
   if (!p)
@@ -164,6 +156,17 @@
     {
       u32 policy_index;
 
+      if (policy->policy == IPSEC_POLICY_ACTION_PROTECT)
+	{
+	  index_t sa_index = ipsec_sa_find_and_lock (policy->sa_id);
+
+	  if (INDEX_INVALID == sa_index)
+	    return VNET_API_ERROR_SYSCALL_ERROR_1;
+	  policy->sa_index = sa_index;
+	}
+      else
+	policy->sa_index = INDEX_INVALID;
+
       pool_get (im->policies, vp);
       clib_memcpy (vp, policy, sizeof (*vp));
       policy_index = vp - im->policies;
@@ -188,6 +191,7 @@
 	if (ipsec_policy_is_equal (vp, policy))
 	  {
 	    vec_del1 (spd->policies[policy->type], ii);
+	    ipsec_sa_unlock (vp->sa_index);
 	    pool_put (im->policies, vp);
 	    break;
 	  }
diff --git a/src/vnet/ipsec/ipsec_tun.c b/src/vnet/ipsec/ipsec_tun.c
index a389cef..46980df 100644
--- a/src/vnet/ipsec/ipsec_tun.c
+++ b/src/vnet/ipsec/ipsec_tun.c
@@ -191,6 +191,7 @@
 ipsec_tun_protect_unconfig (ipsec_main_t * im, ipsec_tun_protect_t * itp)
 {
   ipsec_sa_t *sa;
+  index_t sai;
 
   ipsec_tun_protect_feature_set (itp, 0);
 
@@ -199,9 +200,16 @@
   ({
     ipsec_sa_unset_IS_PROTECT (sa);
   }));
-  /* *INDENT-ON* */
 
   ipsec_tun_protect_db_remove (im, itp);
+
+  ipsec_sa_unlock(itp->itp_out_sa);
+
+  FOR_EACH_IPSEC_PROTECT_INPUT_SAI(itp, sai,
+  ({
+    ipsec_sa_unlock(sai);
+  }));
+  /* *INDENT-ON* */
 }
 
 index_t
@@ -229,7 +237,7 @@
 
   vec_foreach_index (ii, sas_in)
   {
-    sas_in[ii] = ipsec_get_sa_index_by_sa_id (sas_in[ii]);
+    sas_in[ii] = ipsec_sa_find_and_lock (sas_in[ii]);
     if (~0 == sas_in[ii])
       {
 	rv = VNET_API_ERROR_INVALID_VALUE;
@@ -237,7 +245,7 @@
       }
   }
 
-  sa_out = ipsec_get_sa_index_by_sa_id (sa_out);
+  sa_out = ipsec_sa_find_and_lock (sa_out);
 
   if (~0 == sa_out)
     {
diff --git a/src/vnet/ipsec/ipsec_tun.h b/src/vnet/ipsec/ipsec_tun.h
index be5cef9..2041cbe 100644
--- a/src/vnet/ipsec/ipsec_tun.h
+++ b/src/vnet/ipsec/ipsec_tun.h
@@ -32,12 +32,12 @@
 typedef struct ipsec_tun_protect_t_
 {
   CLIB_CACHE_LINE_ALIGN_MARK (cacheline0);
-  u32 itp_out_sa;
+  index_t itp_out_sa;
 
   /* not using a vector since we want the memory inline
    * with this struct */
   u32 itp_n_sa_in;
-  u32 itp_in_sas[4];
+  index_t itp_in_sas[4];
 
   u32 itp_sw_if_index;