ipsec: anti-replay code cleanup

Change-Id: Ib73352d6be26d639a7f9d47ca0570a1248bff04a
Signed-off-by: Damjan Marion <damarion@cisco.com>
diff --git a/src/plugins/dpdk/ipsec/esp_decrypt.c b/src/plugins/dpdk/ipsec/esp_decrypt.c
index dcc276f..349f04c 100644
--- a/src/plugins/dpdk/ipsec/esp_decrypt.c
+++ b/src/plugins/dpdk/ipsec/esp_decrypt.c
@@ -140,7 +140,7 @@
       while (n_left_from > 0 && n_left_to_next > 0)
 	{
 	  clib_error_t *error;
-	  u32 bi0, sa_index0, seq, iv_size;
+	  u32 bi0, sa_index0, iv_size;
 	  u8 trunc_size;
 	  vlib_buffer_t *b0;
 	  esp_header_t *esp0;
@@ -234,33 +234,21 @@
 	    }
 
 	  /* anti-replay check */
-	  if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa0))
+	  if (ipsec_sa_anti_replay_check (sa0, &esp0->seq))
 	    {
-	      int rv = 0;
-
-	      seq = clib_net_to_host_u32 (esp0->seq);
-
-	      if (PREDICT_TRUE (ipsec_sa_is_set_USE_EXTENDED_SEQ_NUM (sa0)))
-		rv = esp_replay_check_esn (sa0, seq);
+	      clib_warning ("failed anti-replay check");
+	      if (is_ip6)
+		vlib_node_increment_counter (vm,
+					     dpdk_esp6_decrypt_node.index,
+					     ESP_DECRYPT_ERROR_REPLAY, 1);
 	      else
-		rv = esp_replay_check (sa0, seq);
-
-	      if (PREDICT_FALSE (rv))
-		{
-		  clib_warning ("failed anti-replay check");
-		  if (is_ip6)
-		    vlib_node_increment_counter (vm,
-						 dpdk_esp6_decrypt_node.index,
-						 ESP_DECRYPT_ERROR_REPLAY, 1);
-		  else
-		    vlib_node_increment_counter (vm,
-						 dpdk_esp4_decrypt_node.index,
-						 ESP_DECRYPT_ERROR_REPLAY, 1);
-		  to_next[0] = bi0;
-		  to_next += 1;
-		  n_left_to_next -= 1;
-		  goto trace;
-		}
+		vlib_node_increment_counter (vm,
+					     dpdk_esp4_decrypt_node.index,
+					     ESP_DECRYPT_ERROR_REPLAY, 1);
+	      to_next[0] = bi0;
+	      to_next += 1;
+	      n_left_to_next -= 1;
+	      goto trace;
 	    }
 
 	  if (is_ip6)
@@ -560,15 +548,7 @@
 
 	  iv_size = cipher_alg->iv_len;
 
-	  if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa0))
-	    {
-	      u32 seq;
-	      seq = clib_host_to_net_u32 (esp0->seq);
-	      if (PREDICT_TRUE (ipsec_sa_is_set_USE_EXTENDED_SEQ_NUM (sa0)))
-		esp_replay_advance_esn (sa0, seq);
-	      else
-		esp_replay_advance (sa0, seq);
-	    }
+	  ipsec_sa_anti_replay_advance (sa0, &esp0->seq);
 
 	  /* if UDP encapsulation is used adjust the address of the IP header */
 	  if (ipsec_sa_is_set_UDP_ENCAP (sa0)
diff --git a/src/vnet/ipsec/ah_decrypt.c b/src/vnet/ipsec/ah_decrypt.c
index 87e1de1..cf95588 100644
--- a/src/vnet/ipsec/ah_decrypt.c
+++ b/src/vnet/ipsec/ah_decrypt.c
@@ -151,20 +151,10 @@
 	  seq = clib_host_to_net_u32 (ah0->seq_no);
 
 	  /* anti-replay check */
-	  if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa0))
+	  if (ipsec_sa_anti_replay_check (sa0, &ah0->seq_no))
 	    {
-	      int rv = 0;
-
-	      if (PREDICT_TRUE (ipsec_sa_is_set_USE_EXTENDED_SEQ_NUM (sa0)))
-		rv = esp_replay_check_esn (sa0, seq);
-	      else
-		rv = esp_replay_check (sa0, seq);
-
-	      if (PREDICT_FALSE (rv))
-		{
-		  i_b0->error = node->errors[AH_DECRYPT_ERROR_REPLAY];
-		  goto trace;
-		}
+	      i_b0->error = node->errors[AH_DECRYPT_ERROR_REPLAY];
+	      goto trace;
 	    }
 
 	  vlib_increment_combined_counter
@@ -210,15 +200,7 @@
 		  goto trace;
 		}
 
-	      if (PREDICT_TRUE (ipsec_sa_is_set_USE_ANTI_REPLAY (sa0)))
-		{
-		  if (PREDICT_TRUE
-		      (ipsec_sa_is_set_USE_EXTENDED_SEQ_NUM (sa0)))
-		    esp_replay_advance_esn (sa0, seq);
-		  else
-		    esp_replay_advance (sa0, seq);
-		}
-
+	      ipsec_sa_anti_replay_advance (sa0, &ah0->seq_no);
 	    }
 
 	  vlib_buffer_advance (i_b0,
diff --git a/src/vnet/ipsec/esp.h b/src/vnet/ipsec/esp.h
index 1f894ab..cc12785 100644
--- a/src/vnet/ipsec/esp.h
+++ b/src/vnet/ipsec/esp.h
@@ -54,133 +54,13 @@
 }) ip6_and_esp_header_t;
 /* *INDENT-ON* */
 
-#define ESP_WINDOW_SIZE		(64)
 #define ESP_SEQ_MAX		(4294967295UL)
 #define ESP_MAX_BLOCK_SIZE	(16)
 #define ESP_MAX_ICV_SIZE	(16)
 
 u8 *format_esp_header (u8 * s, va_list * args);
 
-always_inline int
-esp_replay_check (ipsec_sa_t * sa, u32 seq)
-{
-  u32 diff;
-
-  if (PREDICT_TRUE (seq > sa->last_seq))
-    return 0;
-
-  diff = sa->last_seq - seq;
-
-  if (ESP_WINDOW_SIZE > diff)
-    return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
-  else
-    return 1;
-
-  return 0;
-}
-
-always_inline int
-esp_replay_check_esn (ipsec_sa_t * sa, u32 seq)
-{
-  u32 tl = sa->last_seq;
-  u32 th = sa->last_seq_hi;
-  u32 diff = tl - seq;
-
-  if (PREDICT_TRUE (tl >= (ESP_WINDOW_SIZE - 1)))
-    {
-      if (seq >= (tl - ESP_WINDOW_SIZE + 1))
-	{
-	  sa->seq_hi = th;
-	  if (seq <= tl)
-	    return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
-	  else
-	    return 0;
-	}
-      else
-	{
-	  sa->seq_hi = th + 1;
-	  return 0;
-	}
-    }
-  else
-    {
-      if (seq >= (tl - ESP_WINDOW_SIZE + 1))
-	{
-	  sa->seq_hi = th - 1;
-	  return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
-	}
-      else
-	{
-	  sa->seq_hi = th;
-	  if (seq <= tl)
-	    return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
-	  else
-	    return 0;
-	}
-    }
-
-  return 0;
-}
-
 /* TODO seq increment should be atomic to be accessed by multiple workers */
-always_inline void
-esp_replay_advance (ipsec_sa_t * sa, u32 seq)
-{
-  u32 pos;
-
-  if (seq > sa->last_seq)
-    {
-      pos = seq - sa->last_seq;
-      if (pos < ESP_WINDOW_SIZE)
-	sa->replay_window = ((sa->replay_window) << pos) | 1;
-      else
-	sa->replay_window = 1;
-      sa->last_seq = seq;
-    }
-  else
-    {
-      pos = sa->last_seq - seq;
-      sa->replay_window |= (1ULL << pos);
-    }
-}
-
-always_inline void
-esp_replay_advance_esn (ipsec_sa_t * sa, u32 seq)
-{
-  int wrap = sa->seq_hi - sa->last_seq_hi;
-  u32 pos;
-
-  if (wrap == 0 && seq > sa->last_seq)
-    {
-      pos = seq - sa->last_seq;
-      if (pos < ESP_WINDOW_SIZE)
-	sa->replay_window = ((sa->replay_window) << pos) | 1;
-      else
-	sa->replay_window = 1;
-      sa->last_seq = seq;
-    }
-  else if (wrap > 0)
-    {
-      pos = ~seq + sa->last_seq + 1;
-      if (pos < ESP_WINDOW_SIZE)
-	sa->replay_window = ((sa->replay_window) << pos) | 1;
-      else
-	sa->replay_window = 1;
-      sa->last_seq = seq;
-      sa->last_seq_hi = sa->seq_hi;
-    }
-  else if (wrap < 0)
-    {
-      pos = ~seq + sa->last_seq + 1;
-      sa->replay_window |= (1ULL << pos);
-    }
-  else
-    {
-      pos = sa->last_seq - seq;
-      sa->replay_window |= (1ULL << pos);
-    }
-}
-
 always_inline int
 esp_seq_advance (ipsec_sa_t * sa)
 {
diff --git a/src/vnet/ipsec/esp_decrypt.c b/src/vnet/ipsec/esp_decrypt.c
index 7f3c320..9366619 100644
--- a/src/vnet/ipsec/esp_decrypt.c
+++ b/src/vnet/ipsec/esp_decrypt.c
@@ -134,7 +134,6 @@
       esp_header_t *esp0;
       ipsec_sa_t *sa0;
       u32 sa_index0 = ~0;
-      u32 seq;
       ip4_header_t *ih4 = 0, *oh4 = 0;
       ip6_header_t *ih6 = 0, *oh6 = 0;
       u8 tunnel_mode = 1;
@@ -144,29 +143,18 @@
       esp0 = vlib_buffer_get_current (ib[0]);
       sa_index0 = vnet_buffer (ib[0])->ipsec.sad_index;
       sa0 = pool_elt_at_index (im->sad, sa_index0);
-      seq = clib_host_to_net_u32 (esp0->seq);
 
       /* anti-replay check */
-      if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa0))
+      if (ipsec_sa_anti_replay_check (sa0, &esp0->seq))
 	{
-	  int rv = 0;
-
-	  if (PREDICT_TRUE (ipsec_sa_is_set_USE_EXTENDED_SEQ_NUM (sa0)))
-	    rv = esp_replay_check_esn (sa0, seq);
-	  else
-	    rv = esp_replay_check (sa0, seq);
-
-	  if (PREDICT_FALSE (rv))
-	    {
-	      u32 tmp, off = n_alloc - n_left_from;
-	      /* send original packet to drop node */
-	      tmp = from[off];
-	      from[off] = new_bufs[off];
-	      new_bufs[off] = tmp;
-	      ib[0]->error = node->errors[ESP_DECRYPT_ERROR_REPLAY];
-	      next[0] = ESP_DECRYPT_NEXT_DROP;
-	      goto trace;
-	    }
+	  u32 tmp, off = n_alloc - n_left_from;
+	  /* send original packet to drop node */
+	  tmp = from[off];
+	  from[off] = new_bufs[off];
+	  new_bufs[off] = tmp;
+	  ib[0]->error = node->errors[ESP_DECRYPT_ERROR_REPLAY];
+	  next[0] = ESP_DECRYPT_NEXT_DROP;
+	  goto trace;
 	}
 
       vlib_increment_combined_counter
@@ -197,13 +185,7 @@
 	    }
 	}
 
-      if (PREDICT_TRUE (ipsec_sa_is_set_USE_ANTI_REPLAY (sa0)))
-	{
-	  if (PREDICT_TRUE (ipsec_sa_is_set_USE_EXTENDED_SEQ_NUM (sa0)))
-	    esp_replay_advance_esn (sa0, seq);
-	  else
-	    esp_replay_advance (sa0, seq);
-	}
+      ipsec_sa_anti_replay_advance (sa0, &esp0->seq);
 
       if ((sa0->crypto_alg >= IPSEC_CRYPTO_ALG_AES_CBC_128 &&
 	   sa0->crypto_alg <= IPSEC_CRYPTO_ALG_AES_CBC_256) ||
diff --git a/src/vnet/ipsec/ipsec_sa.h b/src/vnet/ipsec/ipsec_sa.h
index 1cd2153..44f9642 100644
--- a/src/vnet/ipsec/ipsec_sa.h
+++ b/src/vnet/ipsec/ipsec_sa.h
@@ -19,6 +19,8 @@
 #include <vnet/ip/ip.h>
 #include <vnet/fib/fib_node.h>
 
+#define IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE (64)
+
 #define foreach_ipsec_crypto_alg    \
   _ (0, NONE, "none")               \
   _ (1, AES_CBC_128, "aes-cbc-128") \
@@ -94,7 +96,7 @@
 #define _(v, f, s) IPSEC_SA_FLAG_##f = v,
   foreach_ipsec_sa_flags
 #undef _
-} __attribute__ ((packed)) ipsec_sa_flags_t;
+} __clib_packed ipsec_sa_flags_t;
 
 STATIC_ASSERT (sizeof (ipsec_sa_flags_t) == 1, "IPSEC SA flags > 1 byte");
 
@@ -216,6 +218,132 @@
 				       va_list * args);
 extern uword unformat_ipsec_key (unformat_input_t * input, va_list * args);
 
+always_inline int
+ipsec_sa_anti_replay_check (ipsec_sa_t * sa, u32 * seqp)
+{
+  u32 seq, diff, tl, th;
+  if ((sa->flags & IPSEC_SA_FLAG_USE_ANTI_REPLAY) == 0)
+    return 0;
+
+  seq = clib_net_to_host_u32 (*seqp);
+
+  if ((sa->flags & IPSEC_SA_FLAG_USE_EXTENDED_SEQ_NUM) == 0)
+    {
+
+      if (PREDICT_TRUE (seq > sa->last_seq))
+	return 0;
+
+      diff = sa->last_seq - seq;
+
+      if (IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE > diff)
+	return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+      else
+	return 1;
+
+      return 0;
+    }
+
+  tl = sa->last_seq;
+  th = sa->last_seq_hi;
+  diff = tl - seq;
+
+  if (PREDICT_TRUE (tl >= (IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE - 1)))
+    {
+      if (seq >= (tl - IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE + 1))
+	{
+	  sa->seq_hi = th;
+	  if (seq <= tl)
+	    return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+	  else
+	    return 0;
+	}
+      else
+	{
+	  sa->seq_hi = th + 1;
+	  return 0;
+	}
+    }
+  else
+    {
+      if (seq >= (tl - IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE + 1))
+	{
+	  sa->seq_hi = th - 1;
+	  return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+	}
+      else
+	{
+	  sa->seq_hi = th;
+	  if (seq <= tl)
+	    return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+	  else
+	    return 0;
+	}
+    }
+
+  return 0;
+}
+
+always_inline void
+ipsec_sa_anti_replay_advance (ipsec_sa_t * sa, u32 * seqp)
+{
+  u32 pos, seq;
+  if (PREDICT_TRUE (sa->flags & IPSEC_SA_FLAG_USE_ANTI_REPLAY) == 0)
+    return;
+
+  seq = clib_host_to_net_u32 (*seqp);
+  if (PREDICT_TRUE (sa->flags & IPSEC_SA_FLAG_USE_EXTENDED_SEQ_NUM))
+    {
+      int wrap = sa->seq_hi - sa->last_seq_hi;
+
+      if (wrap == 0 && seq > sa->last_seq)
+	{
+	  pos = seq - sa->last_seq;
+	  if (pos < IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE)
+	    sa->replay_window = ((sa->replay_window) << pos) | 1;
+	  else
+	    sa->replay_window = 1;
+	  sa->last_seq = seq;
+	}
+      else if (wrap > 0)
+	{
+	  pos = ~seq + sa->last_seq + 1;
+	  if (pos < IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE)
+	    sa->replay_window = ((sa->replay_window) << pos) | 1;
+	  else
+	    sa->replay_window = 1;
+	  sa->last_seq = seq;
+	  sa->last_seq_hi = sa->seq_hi;
+	}
+      else if (wrap < 0)
+	{
+	  pos = ~seq + sa->last_seq + 1;
+	  sa->replay_window |= (1ULL << pos);
+	}
+      else
+	{
+	  pos = sa->last_seq - seq;
+	  sa->replay_window |= (1ULL << pos);
+	}
+    }
+  else
+    {
+      if (seq > sa->last_seq)
+	{
+	  pos = seq - sa->last_seq;
+	  if (pos < IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE)
+	    sa->replay_window = ((sa->replay_window) << pos) | 1;
+	  else
+	    sa->replay_window = 1;
+	  sa->last_seq = seq;
+	}
+      else
+	{
+	  pos = sa->last_seq - seq;
+	  sa->replay_window |= (1ULL << pos);
+	}
+    }
+}
+
 #endif /* __IPSEC_SPD_SA_H__ */
 
 /*